refactor(lerobot_train.py): add missing sampling weight script
@@ -0,0 +1,213 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample weighting abstraction for training.

This module provides a generic protocol for sample weighting strategies (e.g., RA-BC)
that can be used during training without polluting the training script with
policy-specific code.

Example usage:

    # In training config
    sample_weighting:
        type: rabc
        progress_path: hf://datasets/my-dataset/sarm_progress.parquet
        head_mode: sparse
        kappa: 0.01

    # In training script
    sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
    ...
    weights, stats = sample_weighter.compute_batch_weights(batch)
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, runtime_checkable

import torch

if TYPE_CHECKING:
    from lerobot.policies.pretrained import PreTrainedPolicy


@runtime_checkable
class SampleWeighter(Protocol):
    """
    Protocol for sample weighting strategies during training.

    Implementations compute per-sample weights that can be used to weight
    the loss during training. This enables techniques like:
    - RA-BC (Reward-Aligned Behavior Cloning)
    - Importance sampling
    - Curriculum learning
    - Quality-based filtering
    """

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        """
        Compute per-sample weights for a training batch.

        Args:
            batch: Training batch dictionary containing at minimum an "index" key
                with global frame indices.

        Returns:
            Tuple of:
            - weights: Tensor of shape (batch_size,) with per-sample weights,
              normalized to sum to batch_size for stable gradients.
            - stats: Dictionary with logging-friendly statistics about the weights.
        """
        ...

    def get_stats(self) -> dict:
        """
        Get global statistics about the weighting strategy.

        Returns:
            Dictionary with statistics for logging (e.g., mean delta, coverage).
        """
        ...
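

# Illustrative sketch (not part of the original commit): a minimal concrete
# implementation of the SampleWeighter protocol. The `scores` tensor is a
# hypothetical precomputed per-frame quality signal; real strategies such as
# RA-BC derive it from their own models. Because the protocol is
# @runtime_checkable, `isinstance(_ExampleScoreWeighter(...), SampleWeighter)`
# holds purely by structural matching on the two methods.
class _ExampleScoreWeighter:
    def __init__(self, scores: torch.Tensor, device: torch.device):
        # scores: shape (num_frames,), one quality score per dataset frame.
        self.scores = scores.to(device)
        self.device = device

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        # Look up each sample's score via its global frame index.
        indices = batch["index"].to(self.device)
        weights = self.scores[indices].clamp(min=0.0)
        # Normalize to sum to batch_size, per the protocol's contract.
        weights = weights * weights.shape[0] / (weights.sum() + 1e-6)
        return weights, {"mean_weight": weights.mean().item()}

    def get_stats(self) -> dict:
        return {"type": "example_score", "num_frames": int(self.scores.shape[0])}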


@dataclass
class SampleWeightingConfig:
    """
    Configuration for sample weighting during training.

    This is a generic config that supports multiple weighting strategies.
    The `type` field determines which implementation to use, and `extra_params`
    carries any type-specific parameters.

    Attributes:
        type: Weighting strategy type ("rabc", "uniform", etc.)
        progress_path: Path to precomputed progress values (for RABC)
        head_mode: Which model head to use for progress ("sparse" or "dense")
        kappa: Hard threshold for high-quality samples (RABC-specific)
        epsilon: Small constant for numerical stability
    """

    type: str = "rabc"
    progress_path: str | None = None
    head_mode: str = "sparse"
    kappa: float = 0.01
    epsilon: float = 1e-6
    # Additional type-specific params can be added here or passed via extra_params
    extra_params: dict = field(default_factory=dict)
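

# Example (mirrors the YAML snippet in the module docstring; the dataset path
# is a placeholder):
#
#     cfg = SampleWeightingConfig(
#         type="rabc",
#         progress_path="hf://datasets/my-dataset/sarm_progress.parquet",
#         head_mode="sparse",
#         kappa=0.01,
#     )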


def make_sample_weighter(
    config: SampleWeightingConfig | None,
    policy: PreTrainedPolicy,
    device: torch.device,
) -> SampleWeighter | None:
    """
    Factory function to create a SampleWeighter from config.

    This keeps policy-specific initialization logic out of the training script.

    Args:
        config: Sample weighting configuration, or None to disable weighting.
        policy: The policy being trained (used to extract chunk_size, etc.)
        device: Device to place weight tensors on.

    Returns:
        SampleWeighter instance, or None if config is None.

    Raises:
        ValueError: If the weighting type is unknown or required params are missing.
    """
    if config is None:
        return None

    if config.type == "rabc":
        return _make_rabc_weighter(config, policy, device)

    if config.type == "uniform":
        # No-op weighter that returns uniform weights
        return UniformWeighter(device=device)

    raise ValueError(
        f"Unknown sample weighting type: '{config.type}'. "
        "Supported types: 'rabc', 'uniform'"
    )
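

# Usage sketch (illustrative; `per_sample_loss` is a hypothetical unreduced
# loss tensor of shape (batch_size,)): the factory's None return doubles as
# the "weighting disabled" signal, keeping the training loop branch-light.
#
#     sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
#     ...
#     if sample_weighter is None:
#         loss = per_sample_loss.mean()
#     else:
#         weights, stats = sample_weighter.compute_batch_weights(batch)
#         loss = (per_sample_loss * weights).mean()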


def _make_rabc_weighter(
    config: SampleWeightingConfig,
    policy: PreTrainedPolicy,
    device: torch.device,
) -> SampleWeighter:
    """Create RABC weighter with policy-specific initialization."""
    # Import here to avoid circular imports and keep RABC code in SARM module
    from lerobot.policies.sarm.rabc import RABCWeights

    # Extract chunk_size from policy config
    chunk_size = getattr(policy.config, "chunk_size", None)
    if chunk_size is None:
        raise ValueError(
            "RABC sample weighting requires a policy with 'chunk_size' in its config. "
            "This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
        )

    if config.progress_path is None:
        raise ValueError(
            "RABC sample weighting requires 'progress_path' to be set. "
            "Generate progress values using: "
            "python -m lerobot.policies.sarm.compute_rabc_weights --help"
        )

    return RABCWeights(
        progress_path=config.progress_path,
        chunk_size=chunk_size,
        head_mode=config.head_mode,
        kappa=config.kappa,
        epsilon=config.epsilon,
        device=device,
        **config.extra_params,
    )


class UniformWeighter:
    """
    No-op sample weighter that returns uniform weights.

    Useful as a baseline or when you want to disable weighting without
    changing the training code structure.
    """

    def __init__(self, device: torch.device):
        self.device = device

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        """Return uniform weights (all ones)."""
        # Infer the batch size from the first tensor found among known keys
        batch_size = 1
        for key in ["action", "index"]:
            if key in batch:
                val = batch[key]
                if isinstance(val, torch.Tensor):
                    batch_size = val.shape[0]
                    break

        weights = torch.ones(batch_size, device=self.device)
        stats = {"mean_weight": 1.0, "type": "uniform"}
        return weights, stats

    def get_stats(self) -> dict:
        """Return minimal stats for uniform weighting."""
        return {"type": "uniform"}
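

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original commit): the uniform
    # weighter returns all-ones weights sized to the batch, and the
    # @runtime_checkable protocol recognizes it structurally.
    weighter = UniformWeighter(device=torch.device("cpu"))
    assert isinstance(weighter, SampleWeighter)
    batch = {"action": torch.zeros(8, 7), "index": torch.arange(8)}
    weights, stats = weighter.compute_batch_weights(batch)
    assert weights.shape == (8,) and bool(torch.all(weights == 1.0))
    print("UniformWeighter OK:", stats)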