From ba690632d955f8cb8c72aa259450dc33fd7de0c1 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 12 Jan 2026 11:38:21 +0100 Subject: [PATCH] refactor(lerobot_train.py): add missing sampling weight script --- src/lerobot/utils/sample_weighting.py | 213 ++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 src/lerobot/utils/sample_weighting.py diff --git a/src/lerobot/utils/sample_weighting.py b/src/lerobot/utils/sample_weighting.py new file mode 100644 index 000000000..ec3b5ab07 --- /dev/null +++ b/src/lerobot/utils/sample_weighting.py @@ -0,0 +1,213 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample weighting abstraction for training. + +This module provides a generic protocol for sample weighting strategies (e.g., RA-BC) +that can be used during training without polluting the training script with +policy-specific code. + +Example usage: + # In training config + sample_weighting: + type: rabc + progress_path: hf://datasets/my-dataset/sarm_progress.parquet + head_mode: sparse + kappa: 0.01 + + # In training script + sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device) + ... + weights, stats = sample_weighter.compute_batch_weights(batch) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +import torch + +if TYPE_CHECKING: + from lerobot.policies.pretrained import PreTrainedPolicy + + +@runtime_checkable +class SampleWeighter(Protocol): + """ + Protocol for sample weighting strategies during training. + + Implementations compute per-sample weights that can be used to weight + the loss during training. This enables techniques like: + - RA-BC (Reward-Aligned Behavior Cloning) + - Importance sampling + - Curriculum learning + - Quality-based filtering + """ + + def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]: + """ + Compute per-sample weights for a training batch. + + Args: + batch: Training batch dictionary containing at minimum an "index" key + with global frame indices. + + Returns: + Tuple of: + - weights: Tensor of shape (batch_size,) with per-sample weights, + normalized to sum to batch_size for stable gradients. + - stats: Dictionary with logging-friendly statistics about the weights. + """ + ... + + def get_stats(self) -> dict: + """ + Get global statistics about the weighting strategy. + + Returns: + Dictionary with statistics for logging (e.g., mean delta, coverage). + """ + ... + + +@dataclass +class SampleWeightingConfig: + """ + Configuration for sample weighting during training. + + This is a generic config that supports multiple weighting strategies. + The `type` field determines which implementation to use, and `params` + contains type-specific parameters. + + Attributes: + type: Weighting strategy type ("rabc", "uniform", etc.) + progress_path: Path to precomputed progress values (for RABC) + head_mode: Which model head to use for progress ("sparse" or "dense") + kappa: Hard threshold for high-quality samples (RABC-specific) + epsilon: Small constant for numerical stability + """ + + type: str = "rabc" + progress_path: str | None = None + head_mode: str = "sparse" + kappa: float = 0.01 + epsilon: float = 1e-6 + # Additional type-specific params can be added here or passed via extra_params + extra_params: dict = field(default_factory=dict) + + +def make_sample_weighter( + config: SampleWeightingConfig | None, + policy: PreTrainedPolicy, + device: torch.device, +) -> SampleWeighter | None: + """ + Factory function to create a SampleWeighter from config. + + This keeps policy-specific initialization logic out of the training script. + + Args: + config: Sample weighting configuration, or None to disable weighting. + policy: The policy being trained (used to extract chunk_size, etc.) + device: Device to place weight tensors on. + + Returns: + SampleWeighter instance, or None if config is None. + + Raises: + ValueError: If the weighting type is unknown or required params are missing. + """ + if config is None: + return None + + if config.type == "rabc": + return _make_rabc_weighter(config, policy, device) + + if config.type == "uniform": + # No-op weighter that returns uniform weights + return UniformWeighter(device=device) + + raise ValueError( + f"Unknown sample weighting type: '{config.type}'. " + f"Supported types: 'rabc', 'uniform'" + ) + + +def _make_rabc_weighter( + config: SampleWeightingConfig, + policy: PreTrainedPolicy, + device: torch.device, +) -> SampleWeighter: + """Create RABC weighter with policy-specific initialization.""" + # Import here to avoid circular imports and keep RABC code in SARM module + from lerobot.policies.sarm.rabc import RABCWeights + + # Extract chunk_size from policy config + chunk_size = getattr(policy.config, "chunk_size", None) + if chunk_size is None: + raise ValueError( + "RABC sample weighting requires a policy with 'chunk_size' in its config. " + "This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc." + ) + + if config.progress_path is None: + raise ValueError( + "RABC sample weighting requires 'progress_path' to be set. " + "Generate progress values using: " + "python -m lerobot.policies.sarm.compute_rabc_weights --help" + ) + + return RABCWeights( + progress_path=config.progress_path, + chunk_size=chunk_size, + head_mode=config.head_mode, + kappa=config.kappa, + epsilon=config.epsilon, + device=device, + **config.extra_params, + ) + + +class UniformWeighter: + """ + No-op sample weighter that returns uniform weights. + + Useful as a baseline or when you want to disable weighting without + changing the training code structure. + """ + + def __init__(self, device: torch.device): + self.device = device + + def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]: + """Return uniform weights (all ones).""" + # Determine batch size from batch + batch_size = 1 + for key in ["action", "index"]: + if key in batch: + val = batch[key] + if isinstance(val, torch.Tensor): + batch_size = val.shape[0] + break + + weights = torch.ones(batch_size, device=self.device) + stats = {"mean_weight": 1.0, "type": "uniform"} + return weights, stats + + def get_stats(self) -> dict: + """Return empty stats for uniform weighting.""" + return {"type": "uniform"} +