add automatic detection of the progress path

This commit is contained in:
Michel Aractingi
2026-01-14 17:08:23 +01:00
parent faa276b8cf
commit 94efcea867
4 changed files with 108 additions and 49 deletions
+13 -15
View File
@@ -465,14 +465,13 @@ This script:
### Step 5b: Train Policy with RA-BC
Once you have the progress file, train your policy with RA-BC weighting. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
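Once you have the progress file, train your policy with RA-BC weighting. If `--sample_weighting.progress_path` is not provided, the progress file is auto-detected: the training script looks for `sarm_progress.parquet` in the local dataset root (`--dataset.root`) when one is set, and otherwise in the dataset's Hugging Face Hub repository (`hf://datasets/<repo_id>/sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: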
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--sample_weighting.type=rabc \
--sample_weighting.progress_path=path/to/sarm_progress.parquet \
--sample_weighting.head_mode=sparse \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
@@ -489,13 +488,13 @@ The training script automatically:
**RA-BC Arguments:**
| Argument | Description | Default |
| ----------------------------------- | ------------------------------------------------------ | --------- |
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
| `--sample_weighting.progress_path` | Path to progress parquet file (required for RABC) | (required)|
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
| Argument | Description | Default |
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
| `--sample_weighting.progress_path` | Path to progress parquet file (auto-detected from the dataset root or Hub repo if omitted) | auto-detected (`sarm_progress.parquet` in the dataset) |
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
### Tuning RA-BC Kappa
@@ -513,11 +512,11 @@ The `kappa` parameter is the threshold that determines which samples get full we
Monitor these WandB metrics during training:
| Metric | Healthy Range | Problem Indicator |
| ------------------------------- | ------------- | ------------------------- |
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `sample_weighting/delta_mean` | > 0 | Should be positive |
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
| Metric | Healthy Range | Problem Indicator |
| ----------------------------- | ------------- | ------------------------- |
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `sample_weighting/delta_mean` | > 0 | Should be positive |
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
@@ -553,7 +552,6 @@ accelerate launch \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--sample_weighting.type=rabc \
--sample_weighting.progress_path=path/to/sarm_progress.parquet \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \
+7 -1
View File
@@ -382,7 +382,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if is_main_process:
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
sample_weighter = make_sample_weighter(
cfg.sample_weighting,
policy,
device,
dataset_root=cfg.dataset.root,
dataset_repo_id=cfg.dataset.repo_id,
)
step = 0 # number of policy updates (forward + backward + optim)
+33 -31
View File
@@ -28,7 +28,7 @@ Example usage:
kappa: 0.01
# In training script
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id)
...
weights, stats = sample_weighter.compute_batch_weights(batch)
"""
@@ -37,6 +37,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import torch
@@ -63,21 +64,12 @@ class SampleWeighter(ABC):
Args:
batch: Training batch dictionary containing at minimum an "index" key
with global frame indices.
Returns:
Tuple of:
- weights: Tensor of shape (batch_size,) with per-sample weights,
normalized to sum to batch_size for stable gradients.
- stats: Dictionary with logging-friendly statistics about the weights.
"""
@abstractmethod
def get_stats(self) -> dict:
"""
Get global statistics about the weighting strategy.
Returns:
Dictionary with statistics for logging (e.g., mean delta, coverage).
"""
@@ -112,6 +104,8 @@ def make_sample_weighter(
config: SampleWeightingConfig | None,
policy: PreTrainedPolicy,
device: torch.device,
dataset_root: str | None = None,
dataset_repo_id: str | None = None,
) -> SampleWeighter | None:
"""
Factory function to create a SampleWeighter from config.
@@ -122,18 +116,14 @@ def make_sample_weighter(
config: Sample weighting configuration, or None to disable weighting.
policy: The policy being trained (used to extract chunk_size, etc.)
device: Device to place weight tensors on.
Returns:
SampleWeighter instance, or None if config is None.
Raises:
ValueError: If the weighting type is unknown or required params are missing.
dataset_root: Local path to dataset root (for auto-detecting progress_path).
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
"""
if config is None:
return None
if config.type == "rabc":
return _make_rabc_weighter(config, policy, device)
return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id)
if config.type == "uniform":
# No-op weighter that returns uniform weights
@@ -146,8 +136,18 @@ def _make_rabc_weighter(
config: SampleWeightingConfig,
policy: PreTrainedPolicy,
device: torch.device,
dataset_root: str | None = None,
dataset_repo_id: str | None = None,
) -> SampleWeighter:
"""Create RABC weighter with policy-specific initialization."""
"""Create RABC weighter with policy-specific initialization.
Args:
config: Sample weighting configuration.
policy: The policy being trained (used to extract chunk_size).
device: Device to place weight tensors on.
dataset_root: Local path to dataset root (for auto-detecting progress_path).
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
"""
# Import here to avoid circular imports and keep RABC code in SARM module
from lerobot.policies.sarm.rabc import RABCWeights
@@ -159,15 +159,23 @@ def _make_rabc_weighter(
"This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
)
if config.progress_path is None:
raise ValueError(
"RABC sample weighting requires 'progress_path' to be set. "
"Generate progress values using: "
"python -m lerobot.policies.sarm.compute_rabc_weights --help"
)
# Determine progress_path: use explicit config or auto-detect from dataset
progress_path = config.progress_path
if progress_path is None:
if dataset_root:
progress_path = str(Path(dataset_root) / "sarm_progress.parquet")
elif dataset_repo_id:
progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
else:
raise ValueError(
"RABC sample weighting requires 'progress_path' to be set, "
"or dataset_root/dataset_repo_id for auto-detection. "
"Generate progress values using: "
"python -m lerobot.policies.sarm.compute_rabc_weights --help"
)
return RABCWeights(
progress_path=config.progress_path,
progress_path=progress_path,
chunk_size=chunk_size,
head_mode=config.head_mode,
kappa=config.kappa,
@@ -209,12 +217,6 @@ class UniformWeighter(SampleWeighter):
Args:
batch: Training batch dictionary.
Returns:
Batch size, or 1 if it cannot be determined.
Raises:
ValueError: If batch is empty.
"""
if not batch:
raise ValueError("Cannot determine batch size from empty batch")
+55 -2
View File
@@ -231,8 +231,8 @@ def test_factory_rabc_requires_chunk_size():
make_sample_weighter(config, policy, device)
def test_factory_rabc_requires_progress_path():
"""Test that RABC weighter requires progress_path."""
def test_factory_rabc_requires_progress_path_or_dataset_info():
"""Test that RABC weighter requires progress_path or dataset info for auto-detection."""
config = SampleWeightingConfig(
type="rabc",
progress_path=None, # No progress path
@@ -242,10 +242,63 @@ def test_factory_rabc_requires_progress_path():
policy.config.chunk_size = 50
device = torch.device("cpu")
# Should fail when no progress_path AND no dataset info
with pytest.raises(ValueError, match="progress_path"):
make_sample_weighter(config, policy, device)
def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
    """Check that the factory locates the progress file under ``dataset_root``.

    No explicit ``progress_path`` is configured; only the directory that holds
    the fixture parquet file is passed as ``dataset_root``.
    """
    # NOTE(review): auto-detection looks for "sarm_progress.parquet" inside
    # dataset_root, so this presumably relies on the fixture file having
    # exactly that name - confirm against the fixture definition.
    root_dir = str(sample_progress_parquet.parent)

    # Minimal stand-in for a policy: the factory only needs config.chunk_size.
    fake_policy = Mock()
    fake_policy.config = Mock()
    fake_policy.config.chunk_size = 5

    rabc_config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )

    weighter = make_sample_weighter(
        rabc_config,
        fake_policy,
        torch.device("cpu"),
        dataset_root=root_dir,
    )

    assert weighter is not None

    from lerobot.policies.sarm.rabc import RABCWeights

    assert isinstance(weighter, RABCWeights)
def test_factory_rabc_auto_detects_from_repo_id():
    """Test that RABC weighter constructs HF path from repo_id.

    With no explicit ``progress_path`` and only ``dataset_repo_id`` given, the
    factory should get past its own argument validation and try to load the
    auto-detected progress file. That load is expected to fail here (the file
    doesn't exist), so the test only verifies that the failure is NOT the
    factory's "requires 'progress_path'" validation error.
    """
    config = SampleWeightingConfig(
        type="rabc",
        progress_path=None,  # Not provided, should auto-detect
    )
    policy = Mock()
    policy.config = Mock()
    policy.config.chunk_size = 50
    device = torch.device("cpu")

    # This will construct the path but fail when trying to load (file doesn't exist).
    # The concrete exception type depends on the loader, hence the broad match.
    with pytest.raises(Exception) as exc_info:
        make_sample_weighter(
            config,
            policy,
            device,
            dataset_repo_id="test-user/test-dataset",
        )

    # Should NOT be the "progress_path required" validation error.
    # That error's message itself contains the word "auto-detection", so a check
    # that tolerates "auto-detection" would accept the very error this test is
    # meant to rule out. Match on the message's distinctive phrase instead.
    assert "requires 'progress_path'" not in str(exc_info.value)
# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================