mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
add automatic detection of the progress path
This commit is contained in:
+13
-15
@@ -465,14 +465,13 @@ This script:
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.progress_path=path/to/sarm_progress.parquet \
|
||||
--sample_weighting.head_mode=sparse \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
@@ -489,13 +488,13 @@ The training script automatically:
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ----------------------------------- | ------------------------------------------------------ | --------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file (required for RABC) | (required)|
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
| Argument | Description | Default |
|
||||
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
@@ -513,11 +512,11 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ----------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
@@ -553,7 +552,6 @@ accelerate launch \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.progress_path=path/to/sarm_progress.parquet \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
|
||||
@@ -382,7 +382,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
if is_main_process:
|
||||
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
|
||||
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
|
||||
sample_weighter = make_sample_weighter(
|
||||
cfg.sample_weighting,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=cfg.dataset.root,
|
||||
dataset_repo_id=cfg.dataset.repo_id,
|
||||
)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ Example usage:
|
||||
kappa: 0.01
|
||||
|
||||
# In training script
|
||||
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device)
|
||||
sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id)
|
||||
...
|
||||
weights, stats = sample_weighter.compute_batch_weights(batch)
|
||||
"""
|
||||
@@ -37,6 +37,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -63,21 +64,12 @@ class SampleWeighter(ABC):
|
||||
Args:
|
||||
batch: Training batch dictionary containing at minimum an "index" key
|
||||
with global frame indices.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- weights: Tensor of shape (batch_size,) with per-sample weights,
|
||||
normalized to sum to batch_size for stable gradients.
|
||||
- stats: Dictionary with logging-friendly statistics about the weights.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_stats(self) -> dict:
|
||||
"""
|
||||
Get global statistics about the weighting strategy.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics for logging (e.g., mean delta, coverage).
|
||||
"""
|
||||
|
||||
|
||||
@@ -112,6 +104,8 @@ def make_sample_weighter(
|
||||
config: SampleWeightingConfig | None,
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
dataset_root: str | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> SampleWeighter | None:
|
||||
"""
|
||||
Factory function to create a SampleWeighter from config.
|
||||
@@ -122,18 +116,14 @@ def make_sample_weighter(
|
||||
config: Sample weighting configuration, or None to disable weighting.
|
||||
policy: The policy being trained (used to extract chunk_size, etc.)
|
||||
device: Device to place weight tensors on.
|
||||
|
||||
Returns:
|
||||
SampleWeighter instance, or None if config is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If the weighting type is unknown or required params are missing.
|
||||
dataset_root: Local path to dataset root (for auto-detecting progress_path).
|
||||
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
|
||||
"""
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
if config.type == "rabc":
|
||||
return _make_rabc_weighter(config, policy, device)
|
||||
return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id)
|
||||
|
||||
if config.type == "uniform":
|
||||
# No-op weighter that returns uniform weights
|
||||
@@ -146,8 +136,18 @@ def _make_rabc_weighter(
|
||||
config: SampleWeightingConfig,
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
dataset_root: str | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> SampleWeighter:
|
||||
"""Create RABC weighter with policy-specific initialization."""
|
||||
"""Create RABC weighter with policy-specific initialization.
|
||||
|
||||
Args:
|
||||
config: Sample weighting configuration.
|
||||
policy: The policy being trained (used to extract chunk_size).
|
||||
device: Device to place weight tensors on.
|
||||
dataset_root: Local path to dataset root (for auto-detecting progress_path).
|
||||
dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path).
|
||||
"""
|
||||
# Import here to avoid circular imports and keep RABC code in SARM module
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
@@ -159,15 +159,23 @@ def _make_rabc_weighter(
|
||||
"This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc."
|
||||
)
|
||||
|
||||
if config.progress_path is None:
|
||||
raise ValueError(
|
||||
"RABC sample weighting requires 'progress_path' to be set. "
|
||||
"Generate progress values using: "
|
||||
"python -m lerobot.policies.sarm.compute_rabc_weights --help"
|
||||
)
|
||||
# Determine progress_path: use explicit config or auto-detect from dataset
|
||||
progress_path = config.progress_path
|
||||
if progress_path is None:
|
||||
if dataset_root:
|
||||
progress_path = str(Path(dataset_root) / "sarm_progress.parquet")
|
||||
elif dataset_repo_id:
|
||||
progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet"
|
||||
else:
|
||||
raise ValueError(
|
||||
"RABC sample weighting requires 'progress_path' to be set, "
|
||||
"or dataset_root/dataset_repo_id for auto-detection. "
|
||||
"Generate progress values using: "
|
||||
"python -m lerobot.policies.sarm.compute_rabc_weights --help"
|
||||
)
|
||||
|
||||
return RABCWeights(
|
||||
progress_path=config.progress_path,
|
||||
progress_path=progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=config.head_mode,
|
||||
kappa=config.kappa,
|
||||
@@ -209,12 +217,6 @@ class UniformWeighter(SampleWeighter):
|
||||
|
||||
Args:
|
||||
batch: Training batch dictionary.
|
||||
|
||||
Returns:
|
||||
Batch size, or 1 if it cannot be determined.
|
||||
|
||||
Raises:
|
||||
ValueError: If batch is empty.
|
||||
"""
|
||||
if not batch:
|
||||
raise ValueError("Cannot determine batch size from empty batch")
|
||||
|
||||
@@ -231,8 +231,8 @@ def test_factory_rabc_requires_chunk_size():
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_requires_progress_path():
|
||||
"""Test that RABC weighter requires progress_path."""
|
||||
def test_factory_rabc_requires_progress_path_or_dataset_info():
|
||||
"""Test that RABC weighter requires progress_path or dataset info for auto-detection."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # No progress path
|
||||
@@ -242,10 +242,63 @@ def test_factory_rabc_requires_progress_path():
|
||||
policy.config.chunk_size = 50
|
||||
device = torch.device("cpu")
|
||||
|
||||
# Should fail when no progress_path AND no dataset info
|
||||
with pytest.raises(ValueError, match="progress_path"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet):
|
||||
"""Test that RABC weighter auto-detects progress_path from dataset_root."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # Not provided, should auto-detect
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 5
|
||||
device = torch.device("cpu")
|
||||
|
||||
# The parquet file is at sample_progress_parquet, get its parent directory
|
||||
dataset_root = sample_progress_parquet.parent
|
||||
weighter = make_sample_weighter(
|
||||
config,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=str(dataset_root),
|
||||
)
|
||||
|
||||
assert weighter is not None
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
assert isinstance(weighter, RABCWeights)
|
||||
|
||||
|
||||
def test_factory_rabc_auto_detects_from_repo_id():
|
||||
"""Test that RABC weighter constructs HF path from repo_id."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # Not provided, should auto-detect
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 50
|
||||
device = torch.device("cpu")
|
||||
|
||||
# This will construct the path but fail when trying to load (file doesn't exist)
|
||||
# We just verify it doesn't raise the "progress_path required" error
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
make_sample_weighter(
|
||||
config,
|
||||
policy,
|
||||
device,
|
||||
dataset_repo_id="test-user/test-dataset",
|
||||
)
|
||||
# Should NOT be the "progress_path required" error - it should try to load the file
|
||||
assert (
|
||||
"progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower()
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests with RABCWeights
|
||||
# =============================================================================
|
||||
|
||||
Reference in New Issue
Block a user