mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
refactor(lerobot_train.py): add missing sampling weight script
This commit is contained in:
@@ -0,0 +1,213 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Sample weighting abstraction for training.
|
||||||
|
|
||||||
|
This module provides a generic protocol for sample weighting strategies (e.g., RA-BC)
|
||||||
|
that can be used during training without polluting the training script with
|
||||||
|
policy-specific code.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
# In training config
|
||||||
|
sample_weighting:
|
||||||
|
type: rabc
|
||||||
|
progress_path: hf://datasets/my-dataset/sarm_progress.parquet
|
||||||
|
head_mode: sparse
|
||||||
|
kappa: 0.01
|
||||||
|
|
||||||
|
# In training script
|
||||||
|
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
|
||||||
|
...
|
||||||
|
weights, stats = sample_weighter.compute_batch_weights(batch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SampleWeighter(Protocol):
|
||||||
|
"""
|
||||||
|
Protocol for sample weighting strategies during training.
|
||||||
|
|
||||||
|
Implementations compute per-sample weights that can be used to weight
|
||||||
|
the loss during training. This enables techniques like:
|
||||||
|
- RA-BC (Reward-Aligned Behavior Cloning)
|
||||||
|
- Importance sampling
|
||||||
|
- Curriculum learning
|
||||||
|
- Quality-based filtering
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||||
|
"""
|
||||||
|
Compute per-sample weights for a training batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Training batch dictionary containing at minimum an "index" key
|
||||||
|
with global frame indices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of:
|
||||||
|
- weights: Tensor of shape (batch_size,) with per-sample weights,
|
||||||
|
normalized to sum to batch_size for stable gradients.
|
||||||
|
- stats: Dictionary with logging-friendly statistics about the weights.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get global statistics about the weighting strategy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with statistics for logging (e.g., mean delta, coverage).
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SampleWeightingConfig:
|
||||||
|
"""
|
||||||
|
Configuration for sample weighting during training.
|
||||||
|
|
||||||
|
This is a generic config that supports multiple weighting strategies.
|
||||||
|
The `type` field determines which implementation to use, and `params`
|
||||||
|
contains type-specific parameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
type: Weighting strategy type ("rabc", "uniform", etc.)
|
||||||
|
progress_path: Path to precomputed progress values (for RABC)
|
||||||
|
head_mode: Which model head to use for progress ("sparse" or "dense")
|
||||||
|
kappa: Hard threshold for high-quality samples (RABC-specific)
|
||||||
|
epsilon: Small constant for numerical stability
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "rabc"
|
||||||
|
progress_path: str | None = None
|
||||||
|
head_mode: str = "sparse"
|
||||||
|
kappa: float = 0.01
|
||||||
|
epsilon: float = 1e-6
|
||||||
|
# Additional type-specific params can be added here or passed via extra_params
|
||||||
|
extra_params: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def make_sample_weighter(
|
||||||
|
config: SampleWeightingConfig | None,
|
||||||
|
policy: PreTrainedPolicy,
|
||||||
|
device: torch.device,
|
||||||
|
) -> SampleWeighter | None:
|
||||||
|
"""
|
||||||
|
Factory function to create a SampleWeighter from config.
|
||||||
|
|
||||||
|
This keeps policy-specific initialization logic out of the training script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Sample weighting configuration, or None to disable weighting.
|
||||||
|
policy: The policy being trained (used to extract chunk_size, etc.)
|
||||||
|
device: Device to place weight tensors on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SampleWeighter instance, or None if config is None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the weighting type is unknown or required params are missing.
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if config.type == "rabc":
|
||||||
|
return _make_rabc_weighter(config, policy, device)
|
||||||
|
|
||||||
|
if config.type == "uniform":
|
||||||
|
# No-op weighter that returns uniform weights
|
||||||
|
return UniformWeighter(device=device)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown sample weighting type: '{config.type}'. "
|
||||||
|
f"Supported types: 'rabc', 'uniform'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rabc_weighter(
|
||||||
|
config: SampleWeightingConfig,
|
||||||
|
policy: PreTrainedPolicy,
|
||||||
|
device: torch.device,
|
||||||
|
) -> SampleWeighter:
|
||||||
|
"""Create RABC weighter with policy-specific initialization."""
|
||||||
|
# Import here to avoid circular imports and keep RABC code in SARM module
|
||||||
|
from lerobot.policies.sarm.rabc import RABCWeights
|
||||||
|
|
||||||
|
# Extract chunk_size from policy config
|
||||||
|
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||||
|
if chunk_size is None:
|
||||||
|
raise ValueError(
|
||||||
|
"RABC sample weighting requires a policy with 'chunk_size' in its config. "
|
||||||
|
"This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.progress_path is None:
|
||||||
|
raise ValueError(
|
||||||
|
"RABC sample weighting requires 'progress_path' to be set. "
|
||||||
|
"Generate progress values using: "
|
||||||
|
"python -m lerobot.policies.sarm.compute_rabc_weights --help"
|
||||||
|
)
|
||||||
|
|
||||||
|
return RABCWeights(
|
||||||
|
progress_path=config.progress_path,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
head_mode=config.head_mode,
|
||||||
|
kappa=config.kappa,
|
||||||
|
epsilon=config.epsilon,
|
||||||
|
device=device,
|
||||||
|
**config.extra_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UniformWeighter:
|
||||||
|
"""
|
||||||
|
No-op sample weighter that returns uniform weights.
|
||||||
|
|
||||||
|
Useful as a baseline or when you want to disable weighting without
|
||||||
|
changing the training code structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device: torch.device):
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||||
|
"""Return uniform weights (all ones)."""
|
||||||
|
# Determine batch size from batch
|
||||||
|
batch_size = 1
|
||||||
|
for key in ["action", "index"]:
|
||||||
|
if key in batch:
|
||||||
|
val = batch[key]
|
||||||
|
if isinstance(val, torch.Tensor):
|
||||||
|
batch_size = val.shape[0]
|
||||||
|
break
|
||||||
|
|
||||||
|
weights = torch.ones(batch_size, device=self.device)
|
||||||
|
stats = {"mean_weight": 1.0, "type": "uniform"}
|
||||||
|
return weights, stats
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""Return empty stats for uniform weighting."""
|
||||||
|
return {"type": "uniform"}
|
||||||
|
|
||||||
Reference in New Issue
Block a user