Compare commits

...

15 Commits

Author SHA1 Message Date
Michel Aractingi c4ff7dbf52 Merge branch 'main' into refactor/lerobot_train_rabc 2026-01-26 16:20:55 +01:00
Michel Aractingi 2a6e6bef2b Merge branch 'main' into refactor/lerobot_train_rabc 2026-01-19 16:56:21 +01:00
Michel Aractingi bbd1f1f920 Merge branch 'main' into refactor/lerobot_train_rabc 2026-01-19 16:39:07 +01:00
Michel Aractingi 5872b04851 Merge branch 'main' into refactor/lerobot_train_rabc 2026-01-14 17:33:10 +01:00
Michel Aractingi 80b0f1aaa2 improve comment 2026-01-14 17:13:26 +01:00
Michel Aractingi 0264ac717b remove type exp 2026-01-14 17:09:56 +01:00
Michel Aractingi 94efcea867 add automatic detection of the progress path 2026-01-14 17:08:23 +01:00
Michel Aractingi faa276b8cf update docs 2026-01-14 16:38:26 +01:00
Michel Aractingi fd7cb9a5c5 Merge branch 'main' into refactor/lerobot_train_rabc 2026-01-14 16:09:26 +01:00
Michel Aractingi 25e3119a33 Merge branch 'main' into refactor/lerobot_train_rabc 2026-01-12 17:07:31 +01:00
Michel Aractingi f49b8537ad revert some useless changes, improve typing 2026-01-12 14:29:34 +01:00
Michel Aractingi ded91ca866 add testing for sampl weighter 2026-01-12 14:10:03 +01:00
Michel Aractingi 9ebc144b30 linter + missing files 2026-01-12 11:39:01 +01:00
Michel Aractingi ba690632d9 refactor(lerobot_train.py): add missing sampling weight script 2026-01-12 11:38:21 +01:00
Michel Aractingi eb2e79e22d refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train 2026-01-12 11:38:04 +01:00
6 changed files with 757 additions and 118 deletions
+24 -23
@@ -465,15 +465,15 @@ This script:
### Step 5b: Train Policy with RA-BC
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently, PI0, PI0.5, and SmolVLA are supported with RA-BC:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
--rabc_head_mode=sparse \
--rabc_kappa=0.01 \
--sample_weighting.type=rabc \
--sample_weighting.head_mode=sparse \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
@@ -488,12 +488,13 @@ The training script automatically:
**RA-BC Arguments:**
| Argument | Description | Default |
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
| Argument | Description | Default |
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
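Conceptually, the thresholding that `kappa` controls can be sketched as below. This is a simplified reading of the hard/soft rule (paper Eq. 8-9); the soft weight for deltas inside `[0, kappa]` is assumed here to be a linear ramp, which may differ from the exact formula used in `RABCWeights._compute_weights`:

```python
def rabc_weight(delta: float, kappa: float) -> float:
    """Sketch of RA-BC per-sample weighting from a progress delta.

    Assumption: the soft weight inside [0, kappa] is a linear ramp.
    """
    if delta > kappa:   # clearly positive progress: full weight
        return 1.0
    if delta < 0.0:     # negative progress (going backwards): zero weight
        return 0.0
    return delta / kappa  # in-between: soft weight in [0, 1]
```

With `kappa=0.01`, a sample whose chunk gains `delta=0.02` of progress gets full weight, while one that regresses is dropped entirely.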
### Tuning RA-BC Kappa
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
Monitor these WandB metrics during training:
| Metric | Healthy Range | Problem Indicator |
| ------------------ | ------------- | ------------------------- |
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `rabc_delta_mean` | > 0 | Should be positive |
| `rabc_delta_std` | > 0 | Variance in data quality |
| Metric | Healthy Range | Problem Indicator |
| ----------------------------- | ------------- | ------------------------- |
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `sample_weighting/delta_mean` | > 0 | Should be positive |
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
**Setting kappa based on your data:**
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
```
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
# Most deltas fall in range [0.01, 0.05]
# Option 1: Set kappa = delta_mean (medium selectivity)
--rabc_kappa=0.03
--sample_weighting.kappa=0.03
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
--rabc_kappa=0.05
--sample_weighting.kappa=0.05
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
--rabc_kappa=0.07
--sample_weighting.kappa=0.07
```
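The three options above are simple arithmetic on the logged delta statistics. As a sanity check, with hypothetical per-frame deltas:

```python
import statistics

# Hypothetical per-frame progress deltas, standing in for logged values
deltas = [0.01, 0.02, 0.03, 0.03, 0.04, 0.05]

delta_mean = statistics.mean(deltas)
delta_std = statistics.pstdev(deltas)  # population std, matching np.std's default

kappa_medium = delta_mean                      # Option 1: medium selectivity
kappa_high = delta_mean + delta_std            # Option 2: high selectivity
kappa_very_high = delta_mean + 2 * delta_std   # Option 3: very selective
```

Larger kappa means fewer samples clear the hard threshold and more fall into the soft-weighted (or zero-weighted) regime.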
**When RA-BC may not help:**
@@ -550,8 +551,8 @@ accelerate launch \
src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--use_rabc=true \
--rabc_kappa=0.01 \
--sample_weighting.type=rabc \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
--steps=40000
@@ -576,7 +577,7 @@ accelerate launch \
### RA-BC
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
---
+3 -14
@@ -29,6 +29,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
TRAIN_CONFIG_NAME = "train_config.json"
@@ -67,12 +68,8 @@ class TrainPipelineConfig(HubMixin):
wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
# Sample weighting configuration (e.g., for RA-BC training)
sample_weighting: SampleWeightingConfig | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
@@ -140,14 +137,6 @@ class TrainPipelineConfig(HubMixin):
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
if self.use_rabc and not self.rabc_progress_path:
# Auto-detect from dataset path
repo_id = self.dataset.repo_id
if self.dataset.root:
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
else:
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
This module implements the SampleWeighter protocol for RA-BC training,
which weights training samples based on their task progress as measured
by the SARM reward model.
The weights are computed based on progress deltas:
delta = progress[t + chunk_size] - progress[t]
High-quality samples (positive progress) get higher weights, while
samples with negative progress (going backwards) get zero weight.
See: https://arxiv.org/abs/2509.25358 for the SARM paper.
"""
import logging
from pathlib import Path
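The delta in the docstring above reduces to a small lookup for a single episode. A minimal sketch (clamping the lookahead at the episode end is an assumption here; the real implementation handles episode boundaries through its lookup tables):

```python
def progress_delta(progress: list[float], t: int, chunk_size: int) -> float:
    """Progress gained over one action chunk within a single episode.

    Assumption: the lookahead index is clamped at the last frame.
    """
    t_future = min(t + chunk_size, len(progress) - 1)
    return progress[t_future] - progress[t]
```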
@@ -22,6 +36,8 @@ import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from lerobot.utils.sample_weighting import SampleWeighter
def resolve_hf_path(path: str | Path) -> Path:
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
@@ -34,23 +50,27 @@ def resolve_hf_path(path: str | Path) -> Path:
return Path(path)
class RABCWeights:
class RABCWeights(SampleWeighter):
"""
Load precomputed SARM progress values and compute RA-BC weights during training.
This class implements the SampleWeighter ABC for use with the generic
sample weighting infrastructure in lerobot.
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
During training, computes:
- progress_delta = progress[t + chunk_size] - progress[t]
- rabc_weight based on the delta (paper Eq. 8-9)
Args:
progress_path: Path to parquet file with precomputed progress values
chunk_size: Number of frames ahead for computing progress delta
head_mode: Which SARM head to use ("sparse" or "dense")
kappa: Hard threshold for high-quality samples (default: 0.01)
epsilon: Small constant for numerical stability (default: 1e-6)
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
device: Device to return tensors on
progress_path: Path to parquet file with precomputed progress values.
Supports HuggingFace URLs (hf://datasets/...).
chunk_size: Number of frames ahead for computing progress delta.
head_mode: Which SARM head to use ("sparse" or "dense").
kappa: Hard threshold for high-quality samples (default: 0.01).
epsilon: Small constant for numerical stability (default: 1e-6).
fallback_weight: Weight to use for frames without valid delta (default: 1.0).
device: Device to return tensors on.
"""
def __init__(
@@ -61,7 +81,7 @@ class RABCWeights:
kappa: float = 0.01,
epsilon: float = 1e-6,
fallback_weight: float = 1.0,
device: torch.device = None,
device: torch.device | None = None,
):
self.progress_path = resolve_hf_path(progress_path)
self.chunk_size = chunk_size
@@ -87,8 +107,8 @@ class RABCWeights:
logging.info(f"Using progress column: {self.progress_column}")
self.progress_lookup = {}
self.episode_lookup = {}
self.progress_lookup: dict[int, float] = {}
self.episode_lookup: dict[int, int] = {}
for _, row in self.df.iterrows():
global_idx = int(row["index"])
@@ -100,7 +120,7 @@ class RABCWeights:
self.episode_lookup[global_idx] = episode_idx
# Build episode boundaries for delta computation
self.episode_boundaries = {}
self.episode_boundaries: dict[int, dict[str, int]] = {}
for episode_idx in self.df["episode_index"].unique():
ep_df = self.df[self.df["episode_index"] == episode_idx]
self.episode_boundaries[int(episode_idx)] = {
@@ -114,7 +134,7 @@ class RABCWeights:
# Compute global statistics for weight computation
self._compute_global_stats()
def _compute_global_stats(self):
def _compute_global_stats(self) -> None:
"""Compute global mean and std of progress deltas for weight calculation."""
all_deltas = []
@@ -138,8 +158,8 @@ class RABCWeights:
all_deltas.append(delta)
if all_deltas:
self.delta_mean = max(np.mean(all_deltas), 0.0)
self.delta_std = max(np.std(all_deltas), self.epsilon)
self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
else:
self.delta_mean = 0.0
@@ -157,18 +177,19 @@ class RABCWeights:
4. Compute weight using paper Eq. 8-9
Args:
batch: Training batch containing "index" key with global frame indices
batch: Training batch containing "index" key with global frame indices.
Returns:
Tuple of:
- Weights tensor (batch_size,) normalized to sum to batch_size
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
- Weights tensor (batch_size,) normalized to sum to batch_size.
- Stats dict with weighting statistics for logging.
"""
indices = batch.get("index")
if indices is None:
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
batch_size = self._get_batch_size(batch)
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
return torch.ones(batch_size, device=self.device), stats
# Convert to list of ints
if isinstance(indices, torch.Tensor):
@@ -183,29 +204,29 @@ class RABCWeights:
delta = self._compute_delta(idx)
deltas.append(delta)
deltas = np.array(deltas, dtype=np.float32)
deltas_array = np.array(deltas, dtype=np.float32)
# Compute weights from deltas
weights = self._compute_weights(deltas)
weights = self._compute_weights(deltas_array)
# Compute stats before normalization for logging
raw_mean_weight = float(np.nanmean(weights))
num_zero_weight = int(np.sum(weights == 0))
num_full_weight = int(np.sum(weights == 1.0))
batch_stats = {
"raw_mean_weight": raw_mean_weight,
"mean_weight": raw_mean_weight,
"num_zero_weight": num_zero_weight,
"num_full_weight": num_full_weight,
}
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
# Normalize to sum to batch_size
batch_size = len(weights)
weight_sum = weights.sum() + self.epsilon
weights = weights * batch_size / weight_sum
batch_size = len(weights_tensor)
weight_sum = weights_tensor.sum() + self.epsilon
weights_tensor = weights_tensor * batch_size / weight_sum
return weights, batch_stats
return weights_tensor, batch_stats
def _compute_delta(self, global_idx: int) -> float:
"""Compute progress delta for a single frame."""
@@ -241,7 +262,7 @@ class RABCWeights:
- Final weight: w_i = 1{r_i > κ} + 1{0 ≤ r_i ≤ κ} * w̃_i
Returns:
Array of weights
Array of weights.
"""
valid_mask = ~np.isnan(deltas)
@@ -273,12 +294,13 @@ class RABCWeights:
if key in batch:
val = batch[key]
if isinstance(val, (torch.Tensor, np.ndarray)):
return val.shape[0]
return int(val.shape[0])
return 1
def get_stats(self) -> dict:
"""Get statistics."""
"""Get global statistics about the RA-BC weighting."""
return {
"type": "rabc",
"num_frames": len(self.progress_lookup),
"chunk_size": self.chunk_size,
"head_mode": self.head_mode,
+39 -49
@@ -63,8 +63,8 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
rabc_weights_provider=None,
) -> tuple[MetricsTracker, dict]:
sample_weighter=None,
) -> tuple[MetricsTracker, dict | None]:
"""
Performs a single training step to update the policy's weights.
@@ -80,7 +80,7 @@ def update_policy(
accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler.
lock: An optional lock for thread-safe optimizer updates.
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
Returns:
A tuple containing:
@@ -90,27 +90,31 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
# Get RA-BC weights if enabled
rabc_batch_weights = None
rabc_batch_stats = None
if rabc_weights_provider is not None:
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Compute sample weights if a weighter is provided
sample_weights = None
weight_stats = None
if sample_weighter is not None:
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
# Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None:
# Get per-sample losses
if sample_weights is not None:
# Use per-sample loss for weighted training
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
# rabc_batch_weights is already normalized to sum to batch_size
# Weighted loss: each sample's contribution is scaled by its weight.
# We divide by weight sum (not batch size) so that if some weights are zero,
# the remaining samples contribute proportionally more, preserving gradient scale.
# Weights are pre-normalized to sum to batch_size for stable training dynamics.
epsilon = 1e-6
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
# Log raw mean weight (before normalization) - this is the meaningful metric
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
# Log weighting statistics
if output_dict is None:
output_dict = {}
for key, value in weight_stats.items():
output_dict[f"sample_weight_{key}"] = value
else:
loss, output_dict = policy.forward(batch)
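The weighted-loss line above can be checked in plain Python (a simplified stand-in for the tensor version; names are illustrative):

```python
def weighted_loss(per_sample_loss, weights, eps=1e-6):
    # L = sum(w_i * l_i) / (sum(w_i) + eps): zero-weight samples drop out
    # and the remaining samples contribute proportionally more.
    return sum(w * l for w, l in zip(per_sample_loss, weights)) / (sum(weights) + eps)

losses = [1.0, 2.0, 3.0, 4.0]
weights = [0.0, 0.0, 2.0, 2.0]  # pre-normalized: sums to batch_size 4
# weighted_loss(losses, weights) is (6 + 8) / 4, i.e. the mean over the
# surviving samples, not the full batch
```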
@@ -288,27 +292,19 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# Load precomputed SARM progress for RA-BC if enabled
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
rabc_weights = None
if cfg.use_rabc:
from lerobot.utils.rabc import RABCWeights
# Create sample weighter if configured (e.g., for RA-BC training)
sample_weighter = None
if cfg.sample_weighting is not None:
from lerobot.utils.sample_weighting import make_sample_weighter
# Get chunk_size from policy config
chunk_size = getattr(policy.config, "chunk_size", None)
if chunk_size is None:
raise ValueError("Chunk size is not found in policy config")
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
rabc_weights = RABCWeights(
progress_path=cfg.rabc_progress_path,
chunk_size=chunk_size,
head_mode=head_mode,
kappa=getattr(cfg, "rabc_kappa", 0.01),
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
device=device,
if is_main_process:
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
sample_weighter = make_sample_weighter(
cfg.sample_weighting,
policy,
device,
dataset_root=cfg.dataset.root,
dataset_repo_id=cfg.dataset.repo_id,
)
step = 0 # number of policy updates (forward + backward + optim)
@@ -408,7 +404,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
sample_weighter=sample_weighter,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -425,16 +421,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
+239
@@ -0,0 +1,239 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample weighting abstraction for training.
This module provides an abstract base class for sample weighting strategies (e.g., RA-BC)
that can be used during training without polluting the training script with
policy-specific code.
Example usage:
# In training config
sample_weighting:
type: rabc
progress_path: hf://datasets/my-dataset/sarm_progress.parquet
head_mode: sparse
kappa: 0.01
# In training script
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id)
...
weights, stats = sample_weighter.compute_batch_weights(batch)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
from lerobot.policies.pretrained import PreTrainedPolicy
class SampleWeighter(ABC):
"""
Implementations compute per-sample weights that can be used to weight
the loss during training. This enables techniques like:
- RA-BC (Reward-Aligned Behavior Cloning)
- Importance sampling
- Curriculum learning
- Quality-based filtering
"""
@abstractmethod
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""
Compute per-sample weights for a training batch.
Args:
batch: Training batch dictionary containing at minimum an "index" key
with global frame indices.
Returns:
Tuple of (weights tensor of shape (batch_size,), stats dict for logging).
"""
@abstractmethod
def get_stats(self) -> dict:
"""
Get global statistics about the weighting strategy.
"""
@dataclass
class SampleWeightingConfig:
"""
Configuration for sample weighting during training.
This is a generic config that supports multiple weighting strategies.
The `type` field determines which implementation to use, and `extra_params`
contains additional type-specific parameters.
Attributes:
type: Weighting strategy type ("rabc", "uniform", etc.)
progress_path: Path to precomputed progress values (for RABC)
head_mode: Which model head to use for progress ("sparse" or "dense")
kappa: Hard threshold for high-quality samples (RABC-specific)
epsilon: Small constant for numerical stability
extra_params: Additional type-specific parameters passed to the weighter
"""
type: str = "rabc"
progress_path: str | None = None
head_mode: str = "sparse"
kappa: float = 0.01
epsilon: float = 1e-6
# Additional type-specific params can be added here or passed via extra_params
extra_params: dict = field(default_factory=dict)
def make_sample_weighter(
config: SampleWeightingConfig | None,
policy: PreTrainedPolicy,
device: torch.device,
dataset_root: str | None = None,
dataset_repo_id: str | None = None,
) -> SampleWeighter | None:
"""
Factory function to create a SampleWeighter from config.
This keeps policy-specific initialization logic out of the training script.
Args:
config: Sample weighting configuration, or None to disable weighting.
policy: The policy being trained (used to extract chunk_size, etc.)
device: Device to place weight tensors on.
dataset_root: Local path to dataset root (for auto-detecting progress_path).
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
"""
if config is None:
return None
if config.type == "rabc":
return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id)
if config.type == "uniform":
# No-op weighter that returns uniform weights
return UniformWeighter(device=device)
raise ValueError(f"Unknown sample weighting type: '{config.type}'. Supported types: 'rabc', 'uniform'")
def _make_rabc_weighter(
config: SampleWeightingConfig,
policy: PreTrainedPolicy,
device: torch.device,
dataset_root: str | None = None,
dataset_repo_id: str | None = None,
) -> SampleWeighter:
"""Create RABC weighter with policy-specific initialization.
Args:
config: Sample weighting configuration.
policy: The policy being trained (used to extract chunk_size).
device: Device to place weight tensors on.
dataset_root: Local path to dataset root (for auto-detecting progress_path).
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
"""
# Import here to avoid circular imports and keep RABC code in SARM module
from lerobot.policies.sarm.rabc import RABCWeights
# Extract chunk_size from policy config
chunk_size = getattr(policy.config, "chunk_size", None)
if chunk_size is None:
raise ValueError(
"RABC sample weighting requires a policy with 'chunk_size' in its config. "
"This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
)
# Determine progress_path: use explicit config or auto-detect from dataset
progress_path = config.progress_path
if progress_path is None:
if dataset_root:
progress_path = str(Path(dataset_root) / "sarm_progress.parquet")
elif dataset_repo_id:
progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
else:
raise ValueError(
"RABC sample weighting requires 'progress_path' to be set, "
"or dataset_root/dataset_repo_id for auto-detection. "
"Generate progress values using: "
"python -m lerobot.policies.sarm.compute_rabc_weights --help"
)
return RABCWeights(
progress_path=progress_path,
chunk_size=chunk_size,
head_mode=config.head_mode,
kappa=config.kappa,
epsilon=config.epsilon,
device=device,
**config.extra_params,
)
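The `progress_path` resolution above reduces to a three-way fallback. An illustrative standalone helper (`resolve_progress_path` is not part of the PR):

```python
from pathlib import Path

def resolve_progress_path(progress_path, dataset_root, dataset_repo_id):
    """Explicit path wins, then a local dataset root, then a Hub URL."""
    if progress_path is not None:
        return str(progress_path)
    if dataset_root:
        return str(Path(dataset_root) / "sarm_progress.parquet")
    if dataset_repo_id:
        return f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
    raise ValueError("progress_path or dataset_root/dataset_repo_id required")
```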
class UniformWeighter(SampleWeighter):
"""
No-op sample weighter that returns uniform weights.
Useful as a baseline or when you want to disable weighting without
changing the training code structure.
Note:
Batch size is determined by looking for tensor values in the batch
dictionary. The method checks common keys like "action", "index",
and "observation.state" first, then falls back to scanning all values.
"""
def __init__(self, device: torch.device):
self.device = device
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""Return uniform weights (all ones)."""
batch_size = self._determine_batch_size(batch)
weights = torch.ones(batch_size, device=self.device)
stats = {"mean_weight": 1.0, "type": "uniform"}
return weights, stats
def _determine_batch_size(self, batch: dict) -> int:
"""
Determine batch size from the batch dictionary.
Checks common keys first, then scans all values for tensors.
Args:
batch: Training batch dictionary.
"""
if not batch:
raise ValueError("Cannot determine batch size from empty batch")
# Check common keys first
for key in ["action", "index", "observation.state"]:
if key in batch and isinstance(batch[key], torch.Tensor):
return batch[key].shape[0]
# Scan all values for any tensor
for value in batch.values():
if isinstance(value, torch.Tensor) and value.ndim >= 1:
return value.shape[0]
# Last resort: return 1 (this handles non-tensor batches)
return 1
def get_stats(self) -> dict:
"""Return empty stats for uniform weighting."""
return {"type": "uniform"}
+398
@@ -0,0 +1,398 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the sample weighting infrastructure."""
from unittest.mock import Mock
import pytest
import torch
from lerobot.utils.sample_weighting import (
SampleWeighter,
SampleWeightingConfig,
UniformWeighter,
make_sample_weighter,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def sample_progress_parquet(tmp_path):
"""Create a sample progress parquet file for testing."""
import pandas as pd
# Create sample progress data for 2 episodes with 10 frames each
data = {
"index": list(range(20)),
"episode_index": [0] * 10 + [1] * 10,
"frame_index": list(range(10)) * 2,
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
}
df = pd.DataFrame(data)
parquet_path = tmp_path / "sarm_progress.parquet"
df.to_parquet(parquet_path)
return parquet_path
# =============================================================================
# SampleWeightingConfig Tests
# =============================================================================
def test_config_default_values():
"""Test default configuration values."""
config = SampleWeightingConfig()
assert config.type == "rabc"
assert config.progress_path is None
assert config.head_mode == "sparse"
assert config.kappa == 0.01
assert config.epsilon == 1e-6
assert config.extra_params == {}
def test_config_custom_values():
"""Test configuration with custom values."""
config = SampleWeightingConfig(
type="rabc",
progress_path="/path/to/progress.parquet",
head_mode="dense",
kappa=0.05,
epsilon=1e-8,
extra_params={"fallback_weight": 0.5},
)
assert config.type == "rabc"
assert config.progress_path == "/path/to/progress.parquet"
assert config.head_mode == "dense"
assert config.kappa == 0.05
assert config.epsilon == 1e-8
assert config.extra_params == {"fallback_weight": 0.5}
def test_config_uniform_type():
"""Test configuration for uniform weighting."""
config = SampleWeightingConfig(type="uniform")
assert config.type == "uniform"


# =============================================================================
# UniformWeighter Tests
# =============================================================================


def test_uniform_weighter_inherits_from_sample_weighter():
    """Test that UniformWeighter is a SampleWeighter."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    assert isinstance(weighter, SampleWeighter)


def test_uniform_weighter_compute_batch_weights_with_action_key():
    """Test weight computation with 'action' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"action": torch.randn(8, 10)}
    weights, stats = weighter.compute_batch_weights(batch)
    assert weights.shape == (8,)
    assert torch.allclose(weights, torch.ones(8))
    assert stats["mean_weight"] == 1.0
    assert stats["type"] == "uniform"


def test_uniform_weighter_compute_batch_weights_with_index_key():
    """Test weight computation with 'index' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"index": torch.arange(16)}
    weights, stats = weighter.compute_batch_weights(batch)
    assert weights.shape == (16,)
    assert torch.allclose(weights, torch.ones(16))


def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
    """Test weight computation with no tensor keys (fallback to size 1)."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"other_key": "some_value"}
    weights, stats = weighter.compute_batch_weights(batch)
    assert weights.shape == (1,)
    assert torch.allclose(weights, torch.ones(1))


def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
    """Test that empty batch raises ValueError."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {}
    with pytest.raises(ValueError, match="empty batch"):
        weighter.compute_batch_weights(batch)


def test_uniform_weighter_compute_batch_weights_scans_all_keys():
    """Test that batch size is determined by scanning all tensor values."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    # Batch with non-standard key containing a tensor
    batch = {"custom_tensor": torch.randn(7, 3)}
    weights, stats = weighter.compute_batch_weights(batch)
    assert weights.shape == (7,)
    assert torch.allclose(weights, torch.ones(7))


def test_uniform_weighter_compute_batch_weights_on_cuda():
    """Test that weights are placed on the correct device."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    weighter = UniformWeighter(device=torch.device("cuda"))
    batch = {"action": torch.randn(4, 10)}
    weights, _ = weighter.compute_batch_weights(batch)
    assert weights.device.type == "cuda"


def test_uniform_weighter_get_stats():
    """Test get_stats returns expected structure."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    stats = weighter.get_stats()
    assert stats == {"type": "uniform"}


# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================


def test_factory_returns_none_for_none_config():
    """Test that None config returns None weighter."""
    policy = Mock()
    device = torch.device("cpu")
    result = make_sample_weighter(None, policy, device)
    assert result is None


def test_factory_creates_uniform_weighter():
    """Test creation of UniformWeighter."""
    config = SampleWeightingConfig(type="uniform")
    policy = Mock()
    device = torch.device("cpu")
    weighter = make_sample_weighter(config, policy, device)
    assert isinstance(weighter, UniformWeighter)
    assert isinstance(weighter, SampleWeighter)


def test_factory_raises_for_unknown_type():
    """Test that an unknown type raises ValueError."""
    config = SampleWeightingConfig(type="unknown_type")
    policy = Mock()
    device = torch.device("cpu")
    with pytest.raises(ValueError, match="Unknown sample weighting type"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_chunk_size():
    """Test that the RA-BC weighter requires chunk_size in the policy config."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path="/path/to/progress.parquet",
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = None  # No chunk_size
    device = torch.device("cpu")
    with pytest.raises(ValueError, match="chunk_size"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_progress_path_or_dataset_info():
    """Test that the RA-BC weighter requires progress_path or dataset info for auto-detection."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # No progress path
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")
    # Should fail when neither progress_path nor dataset info is provided
    with pytest.raises(ValueError, match="progress_path"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """Test that the RA-BC weighter auto-detects progress_path from dataset_root."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")
    # The fixture writes the parquet file at sample_progress_parquet;
    # its parent directory serves as the dataset root.
    dataset_root = sample_progress_parquet.parent
    weighter = make_sample_weighter(
        config,
        policy,
        device,
        dataset_root=str(dataset_root),
    )
    assert weighter is not None
    from lerobot.policies.sarm.rabc import RABCWeights

    assert isinstance(weighter, RABCWeights)


def test_factory_rabc_auto_detects_from_repo_id():
    """Test that the RA-BC weighter constructs the HF path from repo_id."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")
    # This constructs the path but fails when loading it (the file doesn't exist).
    # We only verify that it does not raise the "progress_path required" error.
    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            config,
            policy,
            device,
            dataset_repo_id="test-user/test-dataset",
        )
    # Should NOT be the "progress_path required" error - the factory should try to load the file
    assert (
        "progress_path" not in str(exc_info.value).lower()
        or "auto-detection" in str(exc_info.value).lower()
    )


# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================


def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
    """Test that RABCWeights inherits from SampleWeighter."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    assert isinstance(weighter, SampleWeighter)


def test_rabc_compute_batch_weights(sample_progress_parquet):
    """Test RABCWeights.compute_batch_weights returns correct structure."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, stats = weighter.compute_batch_weights(batch)
    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (4,)
    assert isinstance(stats, dict)
    assert "mean_weight" in stats


def test_rabc_get_stats(sample_progress_parquet):
    """Test RABCWeights.get_stats returns expected structure."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    stats = weighter.get_stats()
    assert stats["type"] == "rabc"
    assert "num_frames" in stats
    assert "chunk_size" in stats
    assert stats["chunk_size"] == 5
    assert "head_mode" in stats
    assert stats["head_mode"] == "sparse"
    assert "delta_mean" in stats
    assert "delta_std" in stats


def test_factory_creates_rabc_weighter(sample_progress_parquet):
    """Test factory creates RABCWeights with valid config."""
    from lerobot.policies.sarm.rabc import RABCWeights

    config = SampleWeightingConfig(
        type="rabc",
        progress_path=str(sample_progress_parquet),
        head_mode="sparse",
        kappa=0.01,
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")
    weighter = make_sample_weighter(config, policy, device)
    assert isinstance(weighter, RABCWeights)
    assert isinstance(weighter, SampleWeighter)


def test_rabc_weights_normalization(sample_progress_parquet):
    """Test that RABCWeights normalizes weights to sum to batch_size."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )
    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, _ = weighter.compute_batch_weights(batch)
    # Weights should be normalized to sum approximately to batch_size
    batch_size = 4
    assert abs(weights.sum().item() - batch_size) < 0.1
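

# The normalization contract exercised by test_rabc_weights_normalization can
# be illustrated standalone. This sketch is illustrative only (plain Python,
# no lerobot imports; the helper test name is not part of the suite's naming
# scheme): scaling raw per-sample weights by batch_size / sum(weights) yields
# weights that sum to the batch size.
def test_normalization_to_batch_size_sketch():
    """Standalone sketch: w_i * B / sum(w) sums to the batch size B."""
    raw = [0.2, 0.4, 0.6, 0.8]  # arbitrary positive raw weights
    batch_size = len(raw)
    total = sum(raw)
    normalized = [w * batch_size / total for w in raw]
    assert abs(sum(normalized) - batch_size) < 1e-9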