refactor(lerobot_train.py): add missing sampling weight script

Michel Aractingi
2026-01-12 11:38:21 +01:00
parent eb2e79e22d
commit ba690632d9
@@ -0,0 +1,213 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample weighting abstraction for training.
This module provides a generic protocol for sample weighting strategies (e.g., RA-BC)
that can be used during training without polluting the training script with
policy-specific code.
Example usage:
# In training config
sample_weighting:
type: rabc
progress_path: hf://datasets/my-dataset/sarm_progress.parquet
head_mode: sparse
kappa: 0.01
# In training script
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
...
weights, stats = sample_weighter.compute_batch_weights(batch)
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, runtime_checkable

import torch

if TYPE_CHECKING:
    from lerobot.policies.pretrained import PreTrainedPolicy
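
# Illustrative sketch of the consuming side (an assumption about the training
# loop, not code defined here): given a per-sample loss of shape (batch_size,),
# the weights fold in as an elementwise product. Because the weights are
# normalized to sum to batch_size, the weighted mean keeps the same scale as
# an unweighted mean:
#
#     weights, stats = sample_weighter.compute_batch_weights(batch)
#     loss = (per_sample_loss * weights).mean()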
@runtime_checkable
class SampleWeighter(Protocol):
    """
    Protocol for sample weighting strategies during training.

    Implementations compute per-sample weights that can be used to weight
    the loss during training. This enables techniques like:
    - RA-BC (Reward-Aligned Behavior Cloning)
    - Importance sampling
    - Curriculum learning
    - Quality-based filtering
    """

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        """
        Compute per-sample weights for a training batch.

        Args:
            batch: Training batch dictionary containing at minimum an "index" key
                with global frame indices.

        Returns:
            Tuple of:
            - weights: Tensor of shape (batch_size,) with per-sample weights,
              normalized to sum to batch_size for stable gradients.
            - stats: Dictionary with logging-friendly statistics about the weights.
        """
        ...

    def get_stats(self) -> dict:
        """
        Get global statistics about the weighting strategy.

        Returns:
            Dictionary with statistics for logging (e.g., mean delta, coverage).
        """
        ...
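
# Minimal conformance sketch (an assumption, not shipped test code): because
# the protocol is @runtime_checkable, isinstance() verifies structural
# conformance at runtime, e.g. in a unit test:
#
#     weighter = UniformWeighter(device=torch.device("cpu"))
#     assert isinstance(weighter, SampleWeighter)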
@dataclass
class SampleWeightingConfig:
    """
    Configuration for sample weighting during training.

    This is a generic config that supports multiple weighting strategies.
    The `type` field determines which implementation to use, and
    `extra_params` carries any type-specific parameters beyond the common
    fields below.

    Attributes:
        type: Weighting strategy type ("rabc", "uniform", etc.)
        progress_path: Path to precomputed progress values (for RABC)
        head_mode: Which model head to use for progress ("sparse" or "dense")
        kappa: Hard threshold for high-quality samples (RABC-specific)
        epsilon: Small constant for numerical stability
    """

    type: str = "rabc"
    progress_path: str | None = None
    head_mode: str = "sparse"
    kappa: float = 0.01
    epsilon: float = 1e-6
    # Additional type-specific params can be added here or passed via extra_params
    extra_params: dict = field(default_factory=dict)
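
# Illustrative construction (values are placeholders echoing the module
# docstring, not recommendations):
#
#     cfg = SampleWeightingConfig(
#         type="rabc",
#         progress_path="hf://datasets/my-dataset/sarm_progress.parquet",
#         head_mode="sparse",
#         kappa=0.01,
#     )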
def make_sample_weighter(
    config: SampleWeightingConfig | None,
    policy: PreTrainedPolicy,
    device: torch.device,
) -> SampleWeighter | None:
    """
    Factory function to create a SampleWeighter from config.

    This keeps policy-specific initialization logic out of the training script.

    Args:
        config: Sample weighting configuration, or None to disable weighting.
        policy: The policy being trained (used to extract chunk_size, etc.)
        device: Device to place weight tensors on.

    Returns:
        SampleWeighter instance, or None if config is None.

    Raises:
        ValueError: If the weighting type is unknown or required params are missing.
    """
    if config is None:
        return None
    if config.type == "rabc":
        return _make_rabc_weighter(config, policy, device)
    if config.type == "uniform":
        # No-op weighter that returns uniform weights
        return UniformWeighter(device=device)
    raise ValueError(
        f"Unknown sample weighting type: '{config.type}'. "
        "Supported types: 'rabc', 'uniform'"
    )
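
# Illustrative call site (an assumption about the surrounding training
# script): a None config disables weighting entirely, so callers can guard
# on the return value:
#
#     weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
#     if weighter is not None:
#         weights, stats = weighter.compute_batch_weights(batch)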
def _make_rabc_weighter(
    config: SampleWeightingConfig,
    policy: PreTrainedPolicy,
    device: torch.device,
) -> SampleWeighter:
    """Create an RABC weighter with policy-specific initialization."""
    # Import here to avoid circular imports and keep RABC code in the SARM module
    from lerobot.policies.sarm.rabc import RABCWeights

    # Extract chunk_size from the policy config
    chunk_size = getattr(policy.config, "chunk_size", None)
    if chunk_size is None:
        raise ValueError(
            "RABC sample weighting requires a policy with 'chunk_size' in its config. "
            "This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
        )
    if config.progress_path is None:
        raise ValueError(
            "RABC sample weighting requires 'progress_path' to be set. "
            "Generate progress values using: "
            "python -m lerobot.policies.sarm.compute_rabc_weights --help"
        )
    return RABCWeights(
        progress_path=config.progress_path,
        chunk_size=chunk_size,
        head_mode=config.head_mode,
        kappa=config.kappa,
        epsilon=config.epsilon,
        device=device,
        **config.extra_params,
    )
class UniformWeighter:
    """
    No-op sample weighter that returns uniform weights.

    Useful as a baseline or when you want to disable weighting without
    changing the training code structure.
    """

    def __init__(self, device: torch.device):
        self.device = device

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        """Return uniform weights (all ones)."""
        # Infer the batch size from the first tensor found under a known key;
        # fall back to a batch size of 1 if neither key holds a tensor.
        batch_size = 1
        for key in ["action", "index"]:
            if key in batch:
                val = batch[key]
                if isinstance(val, torch.Tensor):
                    batch_size = val.shape[0]
                    break
        weights = torch.ones(batch_size, device=self.device)
        stats = {"mean_weight": 1.0, "type": "uniform"}
        return weights, stats

    def get_stats(self) -> dict:
        """Return minimal stats for uniform weighting (no global state to track)."""
        return {"type": "uniform"}