mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00

Compare commits — 15 commits:

- c4ff7dbf52
- 2a6e6bef2b
- bbd1f1f920
- 5872b04851
- 80b0f1aaa2
- 0264ac717b
- 94efcea867
- faa276b8cf
- fd7cb9a5c5
- 25e3119a33
- f49b8537ad
- ded91ca866
- 9ebc144b30
- ba690632d9
- eb2e79e22d

Diff: +24 −23
````diff
@@ -465,15 +465,15 @@ This script:
 
 ### Step 5b: Train Policy with RA-BC
 
-Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
+Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
 
 ```bash
 python src/lerobot/scripts/lerobot_train.py \
     --dataset.repo_id=your-username/your-dataset \
     --policy.type=pi0 \
-    --use_rabc=true \
-    --rabc_head_mode=sparse \
-    --rabc_kappa=0.01 \
+    --sample_weighting.type=rabc \
+    --sample_weighting.head_mode=sparse \
+    --sample_weighting.kappa=0.01 \
    --output_dir=outputs/train/policy_rabc \
    --batch_size=32 \
    --steps=40000
````
```diff
@@ -488,12 +488,13 @@ The training script automatically:
 
 **RA-BC Arguments:**
 
-| Argument | Description | Default |
-| --- | --- | --- |
-| `--use_rabc` | Enable RA-BC sample weighting | `false` |
-| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
-| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
-| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
+| Argument | Description | Default |
+| --- | --- | --- |
+| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
+| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
+| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
+| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
+| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
 
 ### Tuning RA-BC Kappa
```
````diff
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full weight:
 
 Monitor these WandB metrics during training:
 
-| Metric | Healthy Range | Problem Indicator |
-| --- | --- | --- |
-| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
-| `rabc_delta_mean` | > 0 | Should be positive |
-| `rabc_delta_std` | > 0 | Variance in data quality |
+| Metric | Healthy Range | Problem Indicator |
+| --- | --- | --- |
+| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
+| `sample_weighting/delta_mean` | > 0 | Should be positive |
+| `sample_weighting/delta_std` | > 0 | Variance in data quality |
 
-**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
+**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
 
 **Setting kappa based on your data:**
 
-The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
+The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
 
 ```
 # If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
 # Most deltas fall in range [0.01, 0.05]
 
 # Option 1: Set kappa = delta_mean (medium selectivity)
---rabc_kappa=0.03
+--sample_weighting.kappa=0.03
 
 # Option 2: Set kappa = delta_mean + delta_std (high selectivity)
---rabc_kappa=0.05
+--sample_weighting.kappa=0.05
 
 # Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
---rabc_kappa=0.07
+--sample_weighting.kappa=0.07
 ```
 
 **When RA-BC may not help:**
````
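The kappa-selection heuristics in the hunk above are easy to check numerically. A minimal sketch (not part of the PR; the synthetic random deltas below stand in for the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std` statistics you would read off WandB):

```python
import numpy as np

# Synthetic progress deltas standing in for logged statistics (assumption:
# in practice these come from the SARM progress file, not random draws).
rng = np.random.default_rng(0)
deltas = rng.normal(loc=0.03, scale=0.02, size=10_000)

delta_mean = deltas.mean()
delta_std = deltas.std()

# The three selectivity options from the docs above:
kappa_medium = delta_mean                         # Option 1
kappa_high = delta_mean + delta_std               # Option 2
kappa_very_high = delta_mean + 2 * delta_std      # Option 3

# Fraction of samples that would receive full weight (delta > kappa)
for name, kappa in [("medium", kappa_medium), ("high", kappa_high),
                    ("very high", kappa_very_high)]:
    frac_full = (deltas > kappa).mean()
    print(f"{name}: kappa={kappa:.3f}, full-weight fraction={frac_full:.2f}")
```

Higher kappa means fewer samples clear the hard threshold, which is exactly the "mean weight well below 1.0" regime the tuning section recommends.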
```diff
@@ -550,8 +551,8 @@ accelerate launch \
     src/lerobot/scripts/lerobot_train.py \
     --dataset.repo_id=your-username/your-dataset \
     --policy.type=pi0 \
-    --use_rabc=true \
-    --rabc_kappa=0.01 \
+    --sample_weighting.type=rabc \
+    --sample_weighting.kappa=0.01 \
     --output_dir=outputs/train/policy_rabc \
     --batch_size=32 \
     --steps=40000
```
```diff
@@ -576,7 +577,7 @@ accelerate launch \
 ### RA-BC
 
 1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
-2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
+2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
 
 ---
```
```diff
@@ -29,6 +29,7 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.optim import OptimizerConfig
 from lerobot.optim.schedulers import LRSchedulerConfig
 from lerobot.utils.hub import HubMixin
+from lerobot.utils.sample_weighting import SampleWeightingConfig
 
 TRAIN_CONFIG_NAME = "train_config.json"
```
```diff
@@ -67,12 +68,8 @@ class TrainPipelineConfig(HubMixin):
     wandb: WandBConfig = field(default_factory=WandBConfig)
     peft: PeftConfig | None = None
 
-    # RA-BC (Reward-Aligned Behavior Cloning) parameters
-    use_rabc: bool = False  # Enable reward-weighted training
-    rabc_progress_path: str | None = None  # Path to precomputed SARM progress parquet file
-    rabc_kappa: float = 0.01  # Hard threshold for high-quality samples
-    rabc_epsilon: float = 1e-6  # Small constant for numerical stability
-    rabc_head_mode: str | None = "sparse"  # For dual-head models: "sparse" or "dense"
+    # Sample weighting configuration (e.g., for RA-BC training)
+    sample_weighting: SampleWeightingConfig | None = None
 
     # Rename map for the observation to override the image and state keys
     rename_map: dict[str, str] = field(default_factory=dict)
```
```diff
@@ -140,14 +137,6 @@ class TrainPipelineConfig(HubMixin):
                 "'policy.repo_id' argument missing. Please specify it to push the model to the hub."
             )
 
-        if self.use_rabc and not self.rabc_progress_path:
-            # Auto-detect from dataset path
-            repo_id = self.dataset.repo_id
-            if self.dataset.root:
-                self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
-            else:
-                self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
-
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
```
```diff
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
```
```diff
@@ -14,6 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
+
+This module implements the SampleWeighter protocol for RA-BC training,
+which weights training samples based on their task progress as measured
+by the SARM reward model.
+
+The weights are computed based on progress deltas:
+    delta = progress[t + chunk_size] - progress[t]
+
+High-quality samples (positive progress) get higher weights, while
+samples with negative progress (going backwards) get zero weight.
+
+See: https://arxiv.org/abs/2509.25358 for the SARM paper.
+"""
+
 import logging
 from pathlib import Path
```
```diff
@@ -22,6 +36,8 @@ import pandas as pd
 import torch
 from huggingface_hub import hf_hub_download
 
+from lerobot.utils.sample_weighting import SampleWeighter
+
 
 def resolve_hf_path(path: str | Path) -> Path:
     """Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
```
```diff
@@ -34,23 +50,27 @@ def resolve_hf_path(path: str | Path) -> Path:
     return Path(path)
 
 
-class RABCWeights:
+class RABCWeights(SampleWeighter):
     """
     Load precomputed SARM progress values and compute RA-BC weights during training.
 
+    This class implements the SampleWeighter ABC for use with the generic
+    sample weighting infrastructure in lerobot.
+
     Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
     During training, computes:
     - progress_delta = progress[t + chunk_size] - progress[t]
     - rabc_weight based on the delta (paper Eq. 8-9)
 
     Args:
-        progress_path: Path to parquet file with precomputed progress values
-        chunk_size: Number of frames ahead for computing progress delta
-        head_mode: Which SARM head to use ("sparse" or "dense")
-        kappa: Hard threshold for high-quality samples (default: 0.01)
-        epsilon: Small constant for numerical stability (default: 1e-6)
-        fallback_weight: Weight to use for frames without valid delta (default: 1.0)
-        device: Device to return tensors on
+        progress_path: Path to parquet file with precomputed progress values.
+            Supports HuggingFace URLs (hf://datasets/...).
+        chunk_size: Number of frames ahead for computing progress delta.
+        head_mode: Which SARM head to use ("sparse" or "dense").
+        kappa: Hard threshold for high-quality samples (default: 0.01).
+        epsilon: Small constant for numerical stability (default: 1e-6).
+        fallback_weight: Weight to use for frames without valid delta (default: 1.0).
+        device: Device to return tensors on.
     """
 
     def __init__(
```
```diff
@@ -61,7 +81,7 @@ class RABCWeights:
         kappa: float = 0.01,
         epsilon: float = 1e-6,
         fallback_weight: float = 1.0,
-        device: torch.device = None,
+        device: torch.device | None = None,
     ):
         self.progress_path = resolve_hf_path(progress_path)
         self.chunk_size = chunk_size
```
```diff
@@ -87,8 +107,8 @@ class RABCWeights:
 
         logging.info(f"Using progress column: {self.progress_column}")
 
-        self.progress_lookup = {}
-        self.episode_lookup = {}
+        self.progress_lookup: dict[int, float] = {}
+        self.episode_lookup: dict[int, int] = {}
 
         for _, row in self.df.iterrows():
             global_idx = int(row["index"])
```
```diff
@@ -100,7 +120,7 @@ class RABCWeights:
             self.episode_lookup[global_idx] = episode_idx
 
         # Build episode boundaries for delta computation
-        self.episode_boundaries = {}
+        self.episode_boundaries: dict[int, dict[str, int]] = {}
         for episode_idx in self.df["episode_index"].unique():
             ep_df = self.df[self.df["episode_index"] == episode_idx]
             self.episode_boundaries[int(episode_idx)] = {
```
```diff
@@ -114,7 +134,7 @@ class RABCWeights:
         # Compute global statistics for weight computation
         self._compute_global_stats()
 
-    def _compute_global_stats(self):
+    def _compute_global_stats(self) -> None:
         """Compute global mean and std of progress deltas for weight calculation."""
         all_deltas = []
```
```diff
@@ -138,8 +158,8 @@ class RABCWeights:
                 all_deltas.append(delta)
 
         if all_deltas:
-            self.delta_mean = max(np.mean(all_deltas), 0.0)
-            self.delta_std = max(np.std(all_deltas), self.epsilon)
+            self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
+            self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
             logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
         else:
             self.delta_mean = 0.0
```
```diff
@@ -157,18 +177,19 @@ class RABCWeights:
         4. Compute weight using paper Eq. 8-9
 
         Args:
-            batch: Training batch containing "index" key with global frame indices
+            batch: Training batch containing "index" key with global frame indices.
 
         Returns:
             Tuple of:
-            - Weights tensor (batch_size,) normalized to sum to batch_size
-            - Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
+            - Weights tensor (batch_size,) normalized to sum to batch_size.
+            - Stats dict with weighting statistics for logging.
         """
         indices = batch.get("index")
         if indices is None:
             logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
             batch_size = self._get_batch_size(batch)
-            return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
+            stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
+            return torch.ones(batch_size, device=self.device), stats
 
         # Convert to list of ints
         if isinstance(indices, torch.Tensor):
```
```diff
@@ -183,29 +204,29 @@ class RABCWeights:
             delta = self._compute_delta(idx)
             deltas.append(delta)
 
-        deltas = np.array(deltas, dtype=np.float32)
+        deltas_array = np.array(deltas, dtype=np.float32)
 
         # Compute weights from deltas
-        weights = self._compute_weights(deltas)
+        weights = self._compute_weights(deltas_array)
 
         # Compute stats before normalization for logging
         raw_mean_weight = float(np.nanmean(weights))
         num_zero_weight = int(np.sum(weights == 0))
         num_full_weight = int(np.sum(weights == 1.0))
         batch_stats = {
-            "raw_mean_weight": raw_mean_weight,
+            "mean_weight": raw_mean_weight,
             "num_zero_weight": num_zero_weight,
             "num_full_weight": num_full_weight,
         }
 
-        weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
+        weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
 
         # Normalize to sum to batch_size
-        batch_size = len(weights)
-        weight_sum = weights.sum() + self.epsilon
-        weights = weights * batch_size / weight_sum
+        batch_size = len(weights_tensor)
+        weight_sum = weights_tensor.sum() + self.epsilon
+        weights_tensor = weights_tensor * batch_size / weight_sum
 
-        return weights, batch_stats
+        return weights_tensor, batch_stats
 
     def _compute_delta(self, global_idx: int) -> float:
         """Compute progress delta for a single frame."""
```
```diff
@@ -241,7 +262,7 @@ class RABCWeights:
         - Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
 
         Returns:
-            Array of weights
+            Array of weights.
         """
         valid_mask = ~np.isnan(deltas)
```
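The final-weight rule quoted in the docstring above, wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}·w̃i, can be exercised standalone. A hedged sketch: the soft weight w̃i for deltas in [0, κ] is assumed here to be a linear ramp ri/κ; the PR's actual `_compute_weights` may instead use the global delta statistics it computes in `_compute_global_stats`:

```python
import numpy as np

def rabc_weights_sketch(deltas: np.ndarray, kappa: float = 0.01) -> np.ndarray:
    """Illustrative version of the rule wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}·w̃i.

    Assumption: the soft weight w̃i is a linear ramp ri/κ. Negative deltas
    (going backwards) get zero weight; NaN deltas stay NaN so the caller
    can substitute a fallback weight.
    """
    weights = np.full(deltas.shape, np.nan, dtype=np.float64)
    valid = ~np.isnan(deltas)
    d = deltas[valid]
    # Hard pass above kappa, linear ramp inside [0, kappa], zero below 0.
    weights[valid] = np.where(d > kappa, 1.0, np.where(d >= 0.0, d / kappa, 0.0))
    return weights

print(rabc_weights_sketch(np.array([0.02, 0.005, -0.01]), kappa=0.01))
```

The shape matters more than the exact ramp: samples comfortably above the threshold keep full weight, borderline samples are attenuated, and regressions are dropped.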
```diff
@@ -273,12 +294,13 @@ class RABCWeights:
             if key in batch:
                 val = batch[key]
                 if isinstance(val, (torch.Tensor, np.ndarray)):
-                    return val.shape[0]
+                    return int(val.shape[0])
         return 1
 
     def get_stats(self) -> dict:
-        """Get statistics."""
+        """Get global statistics about the RA-BC weighting."""
         return {
             "type": "rabc",
             "num_frames": len(self.progress_lookup),
             "chunk_size": self.chunk_size,
             "head_mode": self.head_mode,
```
```diff
@@ -63,8 +63,8 @@ def update_policy(
     accelerator: Accelerator,
     lr_scheduler=None,
     lock=None,
-    rabc_weights_provider=None,
-) -> tuple[MetricsTracker, dict]:
+    sample_weighter=None,
+) -> tuple[MetricsTracker, dict | None]:
     """
     Performs a single training step to update the policy's weights.
```
```diff
@@ -80,7 +80,7 @@ def update_policy(
         accelerator: The Accelerator instance for distributed training and mixed precision.
         lr_scheduler: An optional learning rate scheduler.
         lock: An optional lock for thread-safe optimizer updates.
-        rabc_weights_provider: Optional RABCWeights instance for sample weighting.
+        sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
 
     Returns:
         A tuple containing:
```
```diff
@@ -90,27 +90,31 @@ def update_policy(
     start_time = time.perf_counter()
     policy.train()
 
-    # Get RA-BC weights if enabled
-    rabc_batch_weights = None
-    rabc_batch_stats = None
-    if rabc_weights_provider is not None:
-        rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
+    # Compute sample weights if a weighter is provided
+    sample_weights = None
+    weight_stats = None
+    if sample_weighter is not None:
+        sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
 
     # Let accelerator handle mixed precision
     with accelerator.autocast():
-        # Use per-sample loss when RA-BC is enabled for proper weighting
-        if rabc_batch_weights is not None:
-            # Get per-sample losses
+        if sample_weights is not None:
+            # Use per-sample loss for weighted training
+            # Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
             per_sample_loss, output_dict = policy.forward(batch, reduction="none")
 
-            # Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
-            # rabc_batch_weights is already normalized to sum to batch_size
+            # Weighted loss: each sample's contribution is scaled by its weight.
+            # We divide by weight sum (not batch size) so that if some weights are zero,
+            # the remaining samples contribute proportionally more, preserving gradient scale.
+            # Weights are pre-normalized to sum to batch_size for stable training dynamics.
             epsilon = 1e-6
-            loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
-            # Log raw mean weight (before normalization) - this is the meaningful metric
-            output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
-            output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
-            output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
+            loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
+
+            # Log weighting statistics
+            if output_dict is None:
+                output_dict = {}
+            for key, value in weight_stats.items():
+                output_dict[f"sample_weight_{key}"] = value
         else:
             loss, output_dict = policy.forward(batch)
```
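The weighted-loss reduction in this hunk, L = Σ(wᵢ·lᵢ) / (Σwᵢ + ε), is easy to check in isolation. A minimal sketch with toy per-sample losses (synthetic numbers, not from the PR):

```python
import torch

epsilon = 1e-6
per_sample_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Weights pre-normalized to sum to batch_size, matching the convention
# the hunk's comments describe.
sample_weights = torch.tensor([2.0, 2.0, 0.0, 0.0])

loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
# Zero-weight samples drop out entirely; the rest form a weighted average:
# (1*2 + 2*2) / (2 + 2) ≈ 1.5
print(loss.item())
```

Note how dividing by the weight sum rather than the batch size keeps the loss on the same scale whether or not some samples are zeroed out.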
```diff
@@ -288,27 +292,19 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
 
-    # Load precomputed SARM progress for RA-BC if enabled
-    # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
-    rabc_weights = None
-    if cfg.use_rabc:
-        from lerobot.utils.rabc import RABCWeights
+    # Create sample weighter if configured (e.g., for RA-BC training)
+    sample_weighter = None
+    if cfg.sample_weighting is not None:
+        from lerobot.utils.sample_weighting import make_sample_weighter
 
-        # Get chunk_size from policy config
-        chunk_size = getattr(policy.config, "chunk_size", None)
-        if chunk_size is None:
-            raise ValueError("Chunk size is not found in policy config")
-
-        head_mode = getattr(cfg, "rabc_head_mode", "sparse")
-        logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
-        logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
-        rabc_weights = RABCWeights(
-            progress_path=cfg.rabc_progress_path,
-            chunk_size=chunk_size,
-            head_mode=head_mode,
-            kappa=getattr(cfg, "rabc_kappa", 0.01),
-            epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
-            device=device,
+        if is_main_process:
+            logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
+        sample_weighter = make_sample_weighter(
+            cfg.sample_weighting,
+            policy,
+            device,
+            dataset_root=cfg.dataset.root,
+            dataset_repo_id=cfg.dataset.repo_id,
         )
 
     step = 0  # number of policy updates (forward + backward + optim)
```
```diff
@@ -408,7 +404,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
                 cfg.optimizer.grad_clip_norm,
                 accelerator=accelerator,
                 lr_scheduler=lr_scheduler,
-                rabc_weights_provider=rabc_weights,
+                sample_weighter=sample_weighter,
             )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
```
```diff
@@ -425,16 +421,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
                 wandb_log_dict = train_tracker.to_dict()
                 if output_dict:
                     wandb_log_dict.update(output_dict)
-                # Log RA-BC statistics if enabled
-                if rabc_weights is not None:
-                    rabc_stats = rabc_weights.get_stats()
-                    wandb_log_dict.update(
-                        {
-                            "rabc_delta_mean": rabc_stats["delta_mean"],
-                            "rabc_delta_std": rabc_stats["delta_std"],
-                            "rabc_num_frames": rabc_stats["num_frames"],
-                        }
-                    )
+                # Log sample weighting statistics if enabled
+                if sample_weighter is not None:
+                    weighter_stats = sample_weighter.get_stats()
+                    wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
                 wandb_logger.log_dict(wandb_log_dict, step)
                 train_tracker.reset_averages()
```
@@ -0,0 +1,239 @@ (new file, all lines added)

```python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample weighting abstraction for training.

This module provides an abstract base class for sample weighting strategies (e.g., RA-BC)
that can be used during training without polluting the training script with
policy-specific code.

Example usage:
    # In training config
    sample_weighting:
        type: rabc
        progress_path: hf://datasets/my-dataset/sarm_progress.parquet
        head_mode: sparse
        kappa: 0.01

    # In training script
    sample_weighter = make_sample_weighter(
        cfg.sample_weighting, policy, device,
        dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id,
    )
    ...
    weights, stats = sample_weighter.compute_batch_weights(batch)
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from lerobot.policies.pretrained import PreTrainedPolicy


class SampleWeighter(ABC):
    """
    Implementations compute per-sample weights that can be used to weight
    the loss during training. This enables techniques like:
    - RA-BC (Reward-Aligned Behavior Cloning)
    - Importance sampling
    - Curriculum learning
    - Quality-based filtering
    """

    @abstractmethod
    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        """
        Compute per-sample weights for a training batch.

        Args:
            batch: Training batch dictionary containing at minimum an "index" key
                with global frame indices.
        """

    @abstractmethod
    def get_stats(self) -> dict:
        """
        Get global statistics about the weighting strategy.
        """


@dataclass
class SampleWeightingConfig:
    """
    Configuration for sample weighting during training.

    This is a generic config that supports multiple weighting strategies.
    The `type` field determines which implementation to use, and `extra_params`
    contains additional type-specific parameters.

    Attributes:
        type: Weighting strategy type ("rabc", "uniform", etc.)
        progress_path: Path to precomputed progress values (for RABC)
        head_mode: Which model head to use for progress ("sparse" or "dense")
        kappa: Hard threshold for high-quality samples (RABC-specific)
        epsilon: Small constant for numerical stability
        extra_params: Additional type-specific parameters passed to the weighter
    """

    type: str = "rabc"
    progress_path: str | None = None
    head_mode: str = "sparse"
    kappa: float = 0.01
    epsilon: float = 1e-6
    # Additional type-specific params can be added here or passed via extra_params
    extra_params: dict = field(default_factory=dict)


def make_sample_weighter(
    config: SampleWeightingConfig | None,
    policy: PreTrainedPolicy,
    device: torch.device,
    dataset_root: str | None = None,
    dataset_repo_id: str | None = None,
) -> SampleWeighter | None:
    """
    Factory function to create a SampleWeighter from config.

    This keeps policy-specific initialization logic out of the training script.

    Args:
        config: Sample weighting configuration, or None to disable weighting.
        policy: The policy being trained (used to extract chunk_size, etc.)
        device: Device to place weight tensors on.
        dataset_root: Local path to dataset root (for auto-detecting progress_path).
        dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
    """
    if config is None:
        return None

    if config.type == "rabc":
        return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id)

    if config.type == "uniform":
        # No-op weighter that returns uniform weights
        return UniformWeighter(device=device)

    raise ValueError(f"Unknown sample weighting type: '{config.type}'. Supported types: 'rabc', 'uniform'")


def _make_rabc_weighter(
    config: SampleWeightingConfig,
    policy: PreTrainedPolicy,
    device: torch.device,
    dataset_root: str | None = None,
    dataset_repo_id: str | None = None,
) -> SampleWeighter:
    """Create RABC weighter with policy-specific initialization.

    Args:
        config: Sample weighting configuration.
        policy: The policy being trained (used to extract chunk_size).
        device: Device to place weight tensors on.
        dataset_root: Local path to dataset root (for auto-detecting progress_path).
        dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
    """
    # Import here to avoid circular imports and keep RABC code in SARM module
    from lerobot.policies.sarm.rabc import RABCWeights

    # Extract chunk_size from policy config
    chunk_size = getattr(policy.config, "chunk_size", None)
    if chunk_size is None:
        raise ValueError(
            "RABC sample weighting requires a policy with 'chunk_size' in its config. "
            "This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
        )

    # Determine progress_path: use explicit config or auto-detect from dataset
    progress_path = config.progress_path
    if progress_path is None:
        if dataset_root:
            progress_path = str(Path(dataset_root) / "sarm_progress.parquet")
        elif dataset_repo_id:
            progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
        else:
            raise ValueError(
                "RABC sample weighting requires 'progress_path' to be set, "
                "or dataset_root/dataset_repo_id for auto-detection. "
                "Generate progress values using: "
                "python -m lerobot.policies.sarm.compute_rabc_weights --help"
            )

    return RABCWeights(
        progress_path=progress_path,
        chunk_size=chunk_size,
        head_mode=config.head_mode,
        kappa=config.kappa,
        epsilon=config.epsilon,
        device=device,
        **config.extra_params,
    )


class UniformWeighter(SampleWeighter):
    """
    No-op sample weighter that returns uniform weights.

    Useful as a baseline or when you want to disable weighting without
    changing the training code structure.

    Note:
        Batch size is determined by looking for tensor values in the batch
        dictionary. The method checks common keys like "action", "index",
        and "observation.state" first, then falls back to scanning all values.
    """

    def __init__(self, device: torch.device):
        self.device = device

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        """Return uniform weights (all ones)."""
        batch_size = self._determine_batch_size(batch)

        weights = torch.ones(batch_size, device=self.device)
        stats = {"mean_weight": 1.0, "type": "uniform"}
        return weights, stats

    def _determine_batch_size(self, batch: dict) -> int:
        """
        Determine batch size from the batch dictionary.

        Checks common keys first, then scans all values for tensors.

        Args:
            batch: Training batch dictionary.
        """
        if not batch:
            raise ValueError("Cannot determine batch size from empty batch")

        # Check common keys first
        for key in ["action", "index", "observation.state"]:
            if key in batch and isinstance(batch[key], torch.Tensor):
                return batch[key].shape[0]

        # Scan all values for any tensor
        for value in batch.values():
            if isinstance(value, torch.Tensor) and value.ndim >= 1:
                return value.shape[0]

        # Last resort: return 1 (this handles non-tensor batches)
        return 1

    def get_stats(self) -> dict:
        """Return empty stats for uniform weighting."""
        return {"type": "uniform"}
```
@@ -0,0 +1,398 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the sample weighting infrastructure."""

from unittest.mock import Mock

import pytest
import torch

from lerobot.utils.sample_weighting import (
    SampleWeighter,
    SampleWeightingConfig,
    UniformWeighter,
    make_sample_weighter,
)

# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def sample_progress_parquet(tmp_path):
    """Create a sample progress parquet file for testing."""
    import pandas as pd

    # Create sample progress data for 2 episodes with 10 frames each
    data = {
        "index": list(range(20)),
        "episode_index": [0] * 10 + [1] * 10,
        "frame_index": list(range(10)) * 2,
        "progress_sparse": [i / 10.0 for i in range(10)] * 2,
    }
    df = pd.DataFrame(data)
    parquet_path = tmp_path / "sarm_progress.parquet"
    df.to_parquet(parquet_path)
    return parquet_path


# =============================================================================
# SampleWeightingConfig Tests
# =============================================================================


def test_config_default_values():
    """Test default configuration values."""
    config = SampleWeightingConfig()
    assert config.type == "rabc"
    assert config.progress_path is None
    assert config.head_mode == "sparse"
    assert config.kappa == 0.01
    assert config.epsilon == 1e-6
    assert config.extra_params == {}


def test_config_custom_values():
    """Test configuration with custom values."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path="/path/to/progress.parquet",
        head_mode="dense",
        kappa=0.05,
        epsilon=1e-8,
        extra_params={"fallback_weight": 0.5},
    )
    assert config.type == "rabc"
    assert config.progress_path == "/path/to/progress.parquet"
    assert config.head_mode == "dense"
    assert config.kappa == 0.05
    assert config.epsilon == 1e-8
    assert config.extra_params == {"fallback_weight": 0.5}


def test_config_uniform_type():
    """Test configuration for uniform weighting."""
    config = SampleWeightingConfig(type="uniform")
    assert config.type == "uniform"

# =============================================================================
# UniformWeighter Tests
# =============================================================================


def test_uniform_weighter_inherits_from_sample_weighter():
    """Test that UniformWeighter is a SampleWeighter."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    assert isinstance(weighter, SampleWeighter)


def test_uniform_weighter_compute_batch_weights_with_action_key():
    """Test weight computation with 'action' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"action": torch.randn(8, 10)}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (8,)
    assert torch.allclose(weights, torch.ones(8))
    assert stats["mean_weight"] == 1.0
    assert stats["type"] == "uniform"


def test_uniform_weighter_compute_batch_weights_with_index_key():
    """Test weight computation with 'index' key in batch."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"index": torch.arange(16)}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (16,)
    assert torch.allclose(weights, torch.ones(16))


def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
    """Test weight computation with no tensor keys (fallback to size 1)."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {"other_key": "some_value"}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (1,)
    assert torch.allclose(weights, torch.ones(1))


def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
    """Test that empty batch raises ValueError."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    batch = {}

    with pytest.raises(ValueError, match="empty batch"):
        weighter.compute_batch_weights(batch)


def test_uniform_weighter_compute_batch_weights_scans_all_keys():
    """Test that batch size is determined by scanning all tensor values."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    # Batch with non-standard key containing a tensor
    batch = {"custom_tensor": torch.randn(7, 3)}

    weights, stats = weighter.compute_batch_weights(batch)

    assert weights.shape == (7,)
    assert torch.allclose(weights, torch.ones(7))


def test_uniform_weighter_compute_batch_weights_on_cuda():
    """Test that weights are placed on the correct device."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    weighter = UniformWeighter(device=torch.device("cuda"))
    batch = {"action": torch.randn(4, 10)}

    weights, _ = weighter.compute_batch_weights(batch)

    assert weights.device.type == "cuda"


def test_uniform_weighter_get_stats():
    """Test get_stats returns expected structure."""
    weighter = UniformWeighter(device=torch.device("cpu"))
    stats = weighter.get_stats()

    assert stats == {"type": "uniform"}

# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================


def test_factory_returns_none_for_none_config():
    """Test that None config returns None weighter."""
    policy = Mock()
    device = torch.device("cpu")

    result = make_sample_weighter(None, policy, device)

    assert result is None


def test_factory_creates_uniform_weighter():
    """Test creation of UniformWeighter."""
    config = SampleWeightingConfig(type="uniform")
    policy = Mock()
    device = torch.device("cpu")

    weighter = make_sample_weighter(config, policy, device)

    assert isinstance(weighter, UniformWeighter)
    assert isinstance(weighter, SampleWeighter)


def test_factory_raises_for_unknown_type():
    """Test that unknown type raises ValueError."""
    config = SampleWeightingConfig(type="unknown_type")
    policy = Mock()
    device = torch.device("cpu")

    with pytest.raises(ValueError, match="Unknown sample weighting type"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_chunk_size():
    """Test that RABC weighter requires chunk_size in policy config."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path="/path/to/progress.parquet",
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = None  # No chunk_size
    device = torch.device("cpu")

    with pytest.raises(ValueError, match="chunk_size"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_requires_progress_path_or_dataset_info():
    """Test that RABC weighter requires progress_path or dataset info for auto-detection."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # No progress path
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")

    # Should fail when no progress_path AND no dataset info
    with pytest.raises(ValueError, match="progress_path"):
        make_sample_weighter(config, policy, device)


def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """Test that RABC weighter auto-detects progress_path from dataset_root."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")

    # The parquet file is at sample_progress_parquet, get its parent directory
    dataset_root = sample_progress_parquet.parent
    weighter = make_sample_weighter(
        config,
        policy,
        device,
        dataset_root=str(dataset_root),
    )

    assert weighter is not None
    from lerobot.policies.sarm.rabc import RABCWeights

    assert isinstance(weighter, RABCWeights)


def test_factory_rabc_auto_detects_from_repo_id():
    """Test that RABC weighter constructs HF path from repo_id."""
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")

    # This will construct the path but fail when trying to load (file doesn't exist)
    # We just verify it doesn't raise the "progress_path required" error
    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            config,
            policy,
            device,
            dataset_repo_id="test-user/test-dataset",
        )
    # Should NOT be the "progress_path required" error - it should try to load the file
    assert (
        "progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
    )

# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================


def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
    """Test that RABCWeights inherits from SampleWeighter."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )
    assert isinstance(weighter, SampleWeighter)


def test_rabc_compute_batch_weights(sample_progress_parquet):
    """Test RABCWeights.compute_batch_weights returns correct structure."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )

    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, stats = weighter.compute_batch_weights(batch)

    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (4,)
    assert isinstance(stats, dict)
    assert "mean_weight" in stats


def test_rabc_get_stats(sample_progress_parquet):
    """Test RABCWeights.get_stats returns expected structure."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
    )

    stats = weighter.get_stats()

    assert stats["type"] == "rabc"
    assert "num_frames" in stats
    assert "chunk_size" in stats
    assert stats["chunk_size"] == 5
    assert "head_mode" in stats
    assert stats["head_mode"] == "sparse"
    assert "delta_mean" in stats
    assert "delta_std" in stats


def test_factory_creates_rabc_weighter(sample_progress_parquet):
    """Test factory creates RABCWeights with valid config."""
    from lerobot.policies.sarm.rabc import RABCWeights

    config = SampleWeightingConfig(
        type="rabc",
        progress_path=str(sample_progress_parquet),
        head_mode="sparse",
        kappa=0.01,
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 5
    device = torch.device("cpu")

    weighter = make_sample_weighter(config, policy, device)

    assert isinstance(weighter, RABCWeights)
    assert isinstance(weighter, SampleWeighter)


def test_rabc_weights_normalization(sample_progress_parquet):
    """Test that RABCWeights normalizes weights to sum to batch_size."""
    from lerobot.policies.sarm.rabc import RABCWeights

    weighter = RABCWeights(
        progress_path=sample_progress_parquet,
        chunk_size=5,
        head_mode="sparse",
        device=torch.device("cpu"),
    )

    batch = {"index": torch.tensor([0, 1, 2, 3])}
    weights, _ = weighter.compute_batch_weights(batch)

    # Weights should be normalized to sum approximately to batch_size
    batch_size = 4
    assert abs(weights.sum().item() - batch_size) < 0.1
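For context on what these per-sample weights do downstream: weights like those returned by `compute_batch_weights` typically scale each sample's loss before averaging. A minimal sketch (not LeRobot's actual training loop; `weighted_mean_loss` and the numbers are illustrative):

```python
# Hypothetical sketch of weighted behavior-cloning loss aggregation.
# With uniform weights (all ones) this reduces to the plain mean,
# which is why UniformWeighter serves as a no-op baseline.
def weighted_mean_loss(per_sample_losses, weights):
    """Average of per-sample losses, each scaled by its weight."""
    assert len(per_sample_losses) == len(weights)
    total = sum(loss * w for loss, w in zip(per_sample_losses, weights))
    return total / len(per_sample_losses)


losses = [0.2, 0.4, 0.6, 0.8]
uniform = [1.0, 1.0, 1.0, 1.0]   # UniformWeighter output
rabc = [1.6, 1.2, 0.8, 0.4]      # example weights summing to batch size

print(weighted_mean_loss(losses, uniform))  # ≈ 0.5 (plain mean)
print(weighted_mean_loss(losses, rabc))     # ≈ 0.4 (early samples upweighted)
```

Normalizing weights to sum to the batch size, as checked in `test_rabc_weights_normalization`, keeps the weighted loss on the same scale as the unweighted one, so the effective learning rate is unchanged.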