diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index aae7372fa..85a79ef17 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -42,6 +42,10 @@
- local: xvla
title: X-VLA
title: "Policies"
+- sections:
+ - local: sarm
+ title: SARM
+ title: "Reward Models"
- sections:
- local: async
title: Use Async Inference
diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx
new file mode 100644
index 000000000..321097692
--- /dev/null
+++ b/docs/source/sarm.mdx
@@ -0,0 +1,586 @@
+# SARM: Stage-Aware Reward Modeling
+
+SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC).
+
+**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)
+
+## Why Reward Models?
+
+Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy: they contain hesitations, corrections, and trajectories of variable quality. Reward models address this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned progress signal can be used in several ways; two promising applications are (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.
+
+## Overview
+
+SARM has the following features:
+
+1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage
+2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels
+3. **Temporal proportions**: Computes dataset-level priors (α̅\_k) for each subtask to normalize progress across variable-length demonstrations
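
The temporal-proportion prior (Formula (1) in the paper) can be sketched in a few lines: each subtask's share of episode duration is averaged over all demonstrations. The subtask names and frame counts below are invented for illustration:

```python
# Dataset-level temporal proportions: average each subtask's share of
# episode duration across demonstrations. Values here are illustrative.
episodes = [
    {"reach": 30, "fold": 60},   # frames per subtask, demonstration 1
    {"reach": 20, "fold": 100},  # demonstration 2
]

def temporal_proportions(episodes: list[dict]) -> dict[str, float]:
    shares: dict[str, list[float]] = {}
    for ep in episodes:
        total = sum(ep.values())
        for name, duration in ep.items():
            shares.setdefault(name, []).append(duration / total)
    # alpha_k = mean over demonstrations of (subtask duration / episode length)
    return {name: sum(s) / len(s) for name, s in shares.items()}

props = temporal_proportions(episodes)
# "reach" averages (30/90 + 20/120) / 2 = 0.25; "fold" gets the remaining 0.75
```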
+
+SARM trains on a compact **stage+tau** target for each frame:
+
+- **stage**: integer stage index `k ∈ {0, ..., K-1}`
+- **τ (tau)**: within-stage progress `τ ∈ [0, 1]`
+- **target encoding**: `y = k + τ` (this is what the dataset processor produces)
+
+At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`).
+
+This matches **Formula (2)** from the paper:
+
+```
+progress_t = P_{k-1} + α̅_k × τ_t
+```
+
+Where:
+
+- `τ_t = (t - s_k) / (e_k - s_k)` is within-subtask normalized time
+- `P_{k-1}` is cumulative prior (sum of previous subtask proportions)
+- `α̅_k` is the temporal proportion for subtask k
+
+This ensures identical task states map to consistent progress values, even across demonstrations of different lengths.
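
Formula (2) is easy to sketch in Python. The `proportions` values below are invented; in practice they come from `meta/temporal_proportions_*.json`:

```python
# Convert a raw stage+tau prediction (y = k + tau) into normalized
# progress via Formula (2): progress = P_{k-1} + alpha_k * tau.
def normalized_progress(y: float, proportions: list[float]) -> float:
    k = min(int(y), len(proportions) - 1)  # stage index
    tau = y - k                            # within-stage progress in [0, 1]
    p_prev = sum(proportions[:k])          # cumulative prior P_{k-1}
    return p_prev + proportions[k] * tau

# Three subtasks taking on average 25%, 50%, and 25% of an episode:
alpha = [0.25, 0.5, 0.25]
normalized_progress(0.5, alpha)  # halfway through stage 0 -> 0.125
normalized_progress(1.0, alpha)  # start of stage 1 -> 0.25
normalized_progress(2.5, alpha)  # halfway through the last stage -> 0.875
```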
+
+## Inputs and Targets (What the Processor Expects)
+
+SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
+
+- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
+- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
+- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + τ`
+- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation)
+
+At minimum, each training sample needs:
+
+- `task` (string): task description
+- `policy.image_key` images and `policy.state_key` states from the dataset
+
+---
+
+## Annotation Modes
+
+You can choose from **3 annotation modes** that determine how progress labels are computed:
+
+| Mode | Annotations Required | Heads | Use Case |
+| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ |
+| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed |
+| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages |
+| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities |
+
+### Mode Details
+
+
+
+
+**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration.
+
+- **Sparse head**: 1 stage ("task"), linear progress
+- **Dense head**: Not used
+- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available
+
+## Set Up Your Environment
+
+1. Install LeRobot by following our [Installation Guide](./installation).
+2. Install SARM dependencies by running:
+
+```bash
+pip install -e ".[sarm]"
+```
+
+Workflow:
+
+```
+1. Train SARM → 2. Visualize predictions → 3. (Optional) Train policy with RA-BC
+```
+
+
+
+
+**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression.
+
+- **Sparse head**: 1 stage ("task"), linear progress (auto-generated)
+- **Dense head**: Multiple fine-grained stages from VLM annotations
+- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages
+
+Workflow:
+
+```
+1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
+```
+
+
+
+
+**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions.
+
+- **Sparse head**: High-level stages from VLM annotations
+- **Dense head**: Fine-grained stages from VLM annotations
+- **Best for**: Complex multi-stage tasks where both granularities are useful
+
+Workflow:
+
+```
+1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
+```
+
+
+
+
+---
+
+## Step 1: Subtask Annotation
+
+
+
+
+**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically.
+
+
+
+
+Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated.
+
+```bash
+python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
+ --repo-id your-username/your-dataset \
+ --dense-only \
+ --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
+ --video-key observation.images.base \
+ --num-workers 4 \
+ --push-to-hub
+```
+
+**What gets saved:**
+
+- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`)
+- `meta/temporal_proportions_dense.json` - Dense temporal proportions
+- Per-episode columns in `episodes/*.parquet`:
+ - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
+ - (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`)
+
+
+
+
+Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM.
+
+```bash
+python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
+ --repo-id your-username/your-dataset \
+ --sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \
+ --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
+ --video-key observation.images.base \
+ --num-workers 4 \
+ --push-to-hub
+```
+
+**What gets saved:**
+
+- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions
+- `meta/temporal_proportions_dense.json` - Dense temporal proportions
+- Per-episode columns in `episodes/*.parquet`:
+ - `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames`
+ - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
+ - (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`)
+
+
+
+
+### Annotation Arguments
+
+| Argument | Description |
+| ---------------------- | ------------------------------------------------------------------------------- |
+| `--repo-id` | HuggingFace dataset repository ID |
+| `--sparse-subtasks` | Comma-separated list of high-level subtask names |
+| `--dense-subtasks` | Comma-separated list of fine-grained subtask names |
+| `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) |
+| `--video-key` | Camera/video key to use (e.g., `observation.images.top`) |
+| `--num-workers` | Number of parallel GPU workers (default: 1) |
+| `--episodes` | Specific episode indices to annotate (default: all) |
+| `--skip-existing` | Skip episodes that already have annotations |
+| `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) |
+| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) |
+
+> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step.
+
+---
+
+## Step 2: Verify Annotations
+
+
+
+
+**No verification needed!** Skip this step.
+
+
+
+
+Visualize annotations using the `--visualize-only` flag:
+
+```bash
+python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
+ --repo-id your-username/your-dataset \
+ --visualize-only \
+ --visualize-type dense \
+ --num-visualizations 5 \
+ --video-key observation.images.base \
+ --output-dir ./subtask_viz
+```
+
+
+
+
+Visualize annotations using the `--visualize-only` flag:
+
+```bash
+python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
+ --repo-id your-username/your-dataset \
+ --visualize-only \
+ --visualize-type both \
+ --num-visualizations 5 \
+ --video-key observation.images.base \
+ --output-dir ./subtask_viz
+```
+
+
+
+
+This generates visualizations showing video frames with subtask boundaries overlaid and a timeline of subtasks.
+
+### Visualization Arguments
+
+| Argument | Description |
+| ---------------------- | -------------------------------------------------------------- |
+| `--visualize-only` | Only visualize existing annotations (no generation) |
+| `--num-visualizations` | Number of episodes to visualize (default: 5) |
+| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` |
+
+**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run.
+
+---
+
+## Step 3: Train SARM
+
+
+
+
+Train with **no annotations** - uses linear progress from 0 to 1:
+
+```bash
+python src/lerobot/scripts/lerobot_train.py \
+ --dataset.repo_id=your-username/your-dataset \
+ --policy.type=sarm \
+ --policy.annotation_mode=single_stage \
+ --policy.image_key=observation.images.base \
+ --output_dir=outputs/train/sarm_single \
+ --batch_size=32 \
+ --steps=5000 \
+ --wandb.enable=true \
+ --wandb.project=sarm \
+ --policy.repo_id=your-username/your-model-name
+```
+
+
+
+
+Train with **dense annotations only** (sparse auto-generated):
+
+```bash
+python src/lerobot/scripts/lerobot_train.py \
+ --dataset.repo_id=your-username/your-dataset \
+ --policy.type=sarm \
+ --policy.annotation_mode=dense_only \
+ --policy.image_key=observation.images.base \
+ --output_dir=outputs/train/sarm_dense \
+ --batch_size=32 \
+ --steps=5000 \
+ --wandb.enable=true \
+ --wandb.project=sarm \
+ --policy.repo_id=your-username/your-model-name
+```
+
+
+
+
+Train with **both sparse and dense annotations**:
+
+```bash
+python src/lerobot/scripts/lerobot_train.py \
+ --dataset.repo_id=your-username/your-dataset \
+ --policy.type=sarm \
+ --policy.annotation_mode=dual \
+ --policy.image_key=observation.images.base \
+ --output_dir=outputs/train/sarm_dual \
+ --batch_size=32 \
+ --steps=5000 \
+ --wandb.enable=true \
+ --wandb.project=sarm \
+ --policy.repo_id=your-username/your-model-name
+```
+
+
+
+
+### Multi-GPU Training
+
+Prefix the training command with `accelerate launch --multi_gpu --num_processes=4` to train on multiple GPUs.
+
+### Training Arguments
+
+| Argument | Description | Default |
+| -------------------------- | ----------------------------------------------------------------- | ------------------------ |
+| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` |
+| `--policy.image_key` | Camera key for images | `observation.images.top` |
+| `--policy.state_key` | Key for joint states | `observation.state` |
+| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` |
+| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 ≈ 1s) | `30` |
+
+---
+
+## Step 4: Visualize Predictions
+
+Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file.
+
+
+
+
+```bash
+python src/lerobot/policies/sarm/compute_rabc_weights.py \
+ --dataset-repo-id your-username/your-dataset \
+ --reward-model-path your-username/sarm-model \
+ --visualize-only \
+ --num-visualizations 5 \
+ --head-mode sparse \
+ --output-dir ./sarm_viz
+```
+
+
+
+
+```bash
+python src/lerobot/policies/sarm/compute_rabc_weights.py \
+ --dataset-repo-id your-username/your-dataset \
+ --reward-model-path your-username/sarm-model \
+ --visualize-only \
+ --num-visualizations 5 \
+ --head-mode dense \
+ --output-dir ./sarm_viz
+```
+
+
+
+
+```bash
+python src/lerobot/policies/sarm/compute_rabc_weights.py \
+ --dataset-repo-id your-username/your-dataset \
+ --reward-model-path your-username/sarm-model \
+ --visualize-only \
+ --num-visualizations 5 \
+ --head-mode both \
+ --output-dir ./sarm_viz
+```
+
+
+
+
+The visualization shows:
+
+- **Progress plot**: Predicted progress (and optional annotation-derived “GT” when available and `--stride 1`)
+- **Stage probabilities**: Stacked area plot of predicted stage probabilities
+- **Sample frames**: Key frames from the episode with progress/stage labels
+
+### Visualization Arguments
+
+| Argument | Description |
+| ---------------------- | --------------------------------------------------------- |
+| `--visualize-only` | Only visualize predictions (no RABC computation) |
+| `--num-visualizations` | Number of episodes to visualize (default: 5) |
+| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` |
+| `--stride` | Compute every N frames, interpolate the rest (default: 1) |
+
+---
+
+## Step 5 (Optional): Train Policy with RA-BC
+
+Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps:
+
+1. **Precompute progress values** for all frames using the trained SARM model
+2. **Train policy** with RA-BC weighting using the precomputed values
+
+### How RA-BC Works
+
+For each training sample, RA-BC computes the progress delta:
+
+```
+r_i = φ(o_{t+Δ}) - φ(o_t)
+```
+
+Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted.
+
+The weighting follows **Equations 8-9** from the paper:
+
+- **Soft weight**: `w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)`
+- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i`
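
A minimal, self-contained sketch of this weighting rule (not the library's internal implementation; `mu` and `sigma` here are statistics of the batch of deltas):

```python
from statistics import mean, pstdev

def rabc_weights(deltas: list[float], kappa: float = 0.01, eps: float = 1e-6) -> list[float]:
    """Eq. 8-9: soft-weight borderline samples, hard-accept deltas above kappa."""
    mu, sigma = mean(deltas), pstdev(deltas)
    weights = []
    for r in deltas:
        # Eq. 8: clip((r - (mu - 2*sigma)) / (4*sigma + eps), 0, 1)
        soft = min(max((r - (mu - 2 * sigma)) / (4 * sigma + eps), 0.0), 1.0)
        if r > kappa:
            weights.append(1.0)   # clear progress: full weight
        elif r >= 0:
            weights.append(soft)  # borderline: soft weight in [0, 1]
        else:
            weights.append(0.0)   # negative progress: discarded
    return weights

w = rabc_weights([0.05, 0.02, 0.005, -0.01])
# -> [1.0, 1.0, <soft value in (0, 1)>, 0.0] with the default kappa=0.01
```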
+
+### Step 5a: Compute SARM Progress Values
+
+First, run the SARM model on all frames in your dataset to compute progress values:
+
+```bash
+python src/lerobot/policies/sarm/compute_rabc_weights.py \
+ --dataset-repo-id your-username/your-dataset \
+ --reward-model-path your-username/sarm-model \
+ --head-mode sparse \
+ --num-visualizations 5 \
+ --push-to-hub
+```
+
+This script:
+
+- Processes all frames and computes progress values
+- Saves progress values to a parquet file next to the dataset on disk (defaults to `sarm_progress.parquet` in the dataset root)
+- Generates visualizations of the first N episodes (default: 5)
+
+**Arguments:**
+
+| Argument | Description | Default |
+| ---------------------- | -------------------------------------------------------------- | ---------- |
+| `--reward-model-path` | Path to trained SARM model | (required) |
+| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` |
+| `--device` | Device for inference | `cuda` |
+| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` |
+| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip) | `5` |
+
+**Output format** (`sarm_progress.parquet`):
+
+| Column | Description |
+| ----------------- | ---------------------------------------------- |
+| `index` | Global frame index in dataset |
+| `episode_index` | Episode number |
+| `frame_index` | Local frame index within episode |
+| `progress_sparse` | Sparse head progress value [0, 1] |
+| `progress_dense` | Dense head progress value [0, 1] (if computed) |
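
As an illustration of how this file feeds into RA-BC, the snippet below mimics the schema with toy values and computes per-episode progress deltas the way RA-BC does (shifting by the policy's `chunk_size`); in practice the frame would be loaded with `pd.read_parquet("sarm_progress.parquet")`:

```python
import pandas as pd

# Toy frame mimicking sarm_progress.parquet (values are illustrative).
df = pd.DataFrame({
    "episode_index":   [0, 0, 0, 0, 1, 1, 1, 1],
    "frame_index":     [0, 1, 2, 3, 0, 1, 2, 3],
    "progress_sparse": [0.0, 0.2, 0.5, 1.0, 0.0, 0.1, 0.4, 0.9],
})

chunk_size = 2  # the policy's action-chunk length (the Δ in r_i)

# r_i = progress(t + Δ) - progress(t), computed within each episode so
# deltas never cross episode boundaries (trailing frames become NaN).
df["delta"] = (
    df.groupby("episode_index")["progress_sparse"].shift(-chunk_size)
    - df["progress_sparse"]
)
```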
+
+### Step 5b: Train Policy with RA-BC
+
+Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). RA-BC currently supports the PI0, PI0.5, and SmolVLA policies:
+
+```bash
+python src/lerobot/scripts/lerobot_train.py \
+ --dataset.repo_id=your-username/your-dataset \
+ --policy.type=pi0 \
+ --use_rabc=true \
+ --rabc_head_mode=sparse \
+ --rabc_kappa=0.01 \
+ --output_dir=outputs/train/policy_rabc \
+ --batch_size=32 \
+ --steps=40000
+```
+
+The training script automatically:
+
+- Loads the precomputed progress values from the parquet file
+- Uses the policy's `chunk_size` to compute progress deltas (Δ)
+- Computes sample weights based on progress improvement
+- Applies weighted loss during training
+
+**RA-BC Arguments:**
+
+| Argument | Description | Default |
+| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
+| `--use_rabc` | Enable RA-BC sample weighting | `false` |
+| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
+| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
+| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
+
+### Tuning RA-BC Kappa
+
+The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively.
+
+**How the weighting works:**
+
+| Condition | Weight |
+| ------------------- | ----------------------- |
+| `delta > kappa` | 1.0 (hard threshold) |
+| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 8 |
+| `delta < 0` | 0.0 (negative progress) |
+
+**Diagnosing kappa issues:**
+
+Monitor these WandB metrics during training:
+
+| Metric | Healthy Range | Problem Indicator |
+| ------------------ | ------------- | ------------------------- |
+| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
+| `rabc_delta_mean` | > 0 | Should be positive |
+| `rabc_delta_std` | > 0 | Variance in data quality |
+
+**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
+
+**Setting kappa based on your data:**
+
+The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
+
+```
+# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
+# Most deltas fall in range [0.01, 0.05]
+
+# Option 1: Set kappa = delta_mean (medium selectivity)
+--rabc_kappa=0.03
+
+# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
+--rabc_kappa=0.05
+
+# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
+--rabc_kappa=0.07
+```
+
+**When RA-BC may not help:**
+
+If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter.
+
+### Multi-GPU Training with RA-BC
+
+```bash
+accelerate launch \
+ --multi_gpu \
+ --num_processes=4 \
+ src/lerobot/scripts/lerobot_train.py \
+ --dataset.repo_id=your-username/your-dataset \
+ --policy.type=pi0 \
+ --use_rabc=true \
+ --rabc_kappa=0.01 \
+ --output_dir=outputs/train/policy_rabc \
+ --batch_size=32 \
+ --steps=40000
+```
+
+---
+
+## Tips & Best Practices
+
+### Choosing a Mode
+
+- **Start with `single_stage`** for quick experiments - no annotation overhead
+- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages
+- Use **`dual`** for complex tasks where both coarse and fine-grained progress are meaningful
+
+### Annotation Quality
+
+1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center"
+2. **Verify with visualization**: Always check a few episodes before training
+3. **Consistent naming**: Use the same subtask names across all episodes
+
+### RA-BC
+
+1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
+2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
+
+---
+
+## Citation
+
+```bibtex
+@article{chen2025sarm,
+ title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
+ author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
+ journal={arXiv preprint arXiv:2509.25358},
+ year={2025}
+}
+```
diff --git a/pyproject.toml b/pyproject.toml
index 9458c0127..ed3e2ae43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,7 +96,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
-transformers-dep = ["transformers>=4.53.0,<5.0.0"]
+transformers-dep = ["transformers>=4.57.1,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
# Motors
@@ -133,6 +133,7 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
+sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -173,6 +174,7 @@ all = [
"lerobot[phone]",
"lerobot[libero]",
"lerobot[metaworld]",
+ "lerobot[sarm]"
]
[project.scripts]
diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index 13a8d6525..cee9dfdf9 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -65,9 +65,17 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
- checkpoint_path: Path | None = field(init=False, default=None)
+
+ # RA-BC (Reward-Aligned Behavior Cloning) parameters
+ use_rabc: bool = False # Enable reward-weighted training
+ rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
+ rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
+ rabc_epsilon: float = 1e-6 # Small constant for numerical stability
+ rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
+
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
+ checkpoint_path: Path | None = field(init=False, default=None)
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
@@ -131,6 +139,14 @@ class TrainPipelineConfig(HubMixin):
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
+ if self.use_rabc and not self.rabc_progress_path:
+ # Auto-detect from dataset path
+ repo_id = self.dataset.repo_id
+ if self.dataset.root:
+ self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
+ else:
+ self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
+
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
diff --git a/src/lerobot/data_processing/__init__.py b/src/lerobot/data_processing/__init__.py
new file mode 100644
index 000000000..2f76d5676
--- /dev/null
+++ b/src/lerobot/data_processing/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/lerobot/data_processing/sarm_annotations/__init__.py b/src/lerobot/data_processing/sarm_annotations/__init__.py
new file mode 100644
index 000000000..2f76d5676
--- /dev/null
+++ b/src/lerobot/data_processing/sarm_annotations/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py
new file mode 100644
index 000000000..67e37bab8
--- /dev/null
+++ b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py
@@ -0,0 +1,1202 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+SARM Subtask Annotation using local GPU (Qwen3-VL).
+
+This script implements the annotation approach from the SARM paper using local GPU inference:
+"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation"
+Paper: https://arxiv.org/pdf/2509.25358
+
+What it does:
+1. Takes videos from a LeRobot dataset
+2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur
+3. Saves subtask timestamps to the dataset metadata
+4. Optionally pushes the annotated dataset to HuggingFace Hub
+
+SARM trains reward models that predict:
+ - Stage: Which subtask is currently being executed (discrete classification)
+ - Progress: How far along the subtask we are (continuous 0-1)
+
+Supports three annotation modes:
+ 1. No annotations (no args): Auto-creates single sparse "task" stage covering full episode.
+ Use with SARM config annotation_mode="single_stage" for simple tasks.
+
+ 2. Dense-only (--dense-only --dense-subtasks): Dense annotations from VLM, auto-generated
+ single sparse "task" stage. Use with annotation_mode="dense_only".
+
+ 3. Dual mode (--sparse-subtasks + --dense-subtasks): Both sparse and dense annotations
+ from VLM. Use with annotation_mode="dual".
+
+Requirements:
+ - GPU with sufficient VRAM (16GB+ recommended for 30B model)
+ - `pip install transformers torch qwen-vl-utils`
+
+Run with:
+```bash
+python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
+ --repo-id your-username/your-dataset \
+ --sparse-subtasks "Do ..." \
+ --dense-subtasks "Do task 1, Do task 2, Do task 3" \
+ --video-key observation.images.base \
+ --push-to-hub
+```
+"""
+
+import argparse
+import json
+import multiprocessing as mp
+import random
+import re
+import subprocess
+import tempfile
+import textwrap
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any
+
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+from pydantic import BaseModel, Field
+from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+
+# Pydantic Models for SARM Subtask Annotation
+class Timestamp(BaseModel):
+ """Timestamp in MM:SS or SS format"""
+
+ start: str = Field(description="Start timestamp (MM:SS or just seconds)")
+ end: str = Field(description="End timestamp (MM:SS or just seconds)")
+
+
+class Subtask(BaseModel):
+ """Individual subtask/stage - must use EXACT names from provided list"""
+
+ name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
+ timestamps: Timestamp
+
+
+class SubtaskAnnotation(BaseModel):
+ """Complete annotation for a robot manipulation episode"""
+
+ subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
+
+
+def compute_temporal_proportions(
+ annotations: dict[int, Any], fps: int = 30, subtask_order: list[str] | None = None
+) -> dict[str, float]:
+ """
+ Compute dataset-level temporal proportions (priors) for each subtask.
+
+ Implements SARM Paper Formula (1): ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
+
+ Args:
+ annotations: Dict mapping episode index to SubtaskAnnotation object.
+ fps: Frames per second (unused, kept for API compatibility)
+ subtask_order: Optional list defining the output order of subtasks.
+
+ Returns:
+ Dict mapping subtask name to its temporal proportion (ᾱ_k), ordered by subtask_order if provided.
+ """
+ subtask_proportions: dict[str, list[float]] = {}
+
+ for annotation in annotations.values():
+ total_duration = 0
+ durations: dict[str, int] = {}
+
+ for subtask in annotation.subtasks:
+ start_parts = subtask.timestamps.start.split(":")
+ end_parts = subtask.timestamps.end.split(":")
+
+ start_seconds = (
+ int(start_parts[0]) * 60 + int(start_parts[1])
+ if len(start_parts) == 2
+ else int(start_parts[0])
+ )
+ end_seconds = (
+ int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
+ )
+
+ duration = end_seconds - start_seconds
+ durations[subtask.name] = duration
+ total_duration += duration
+
+ if total_duration > 0:
+ for name, duration in durations.items():
+ if name not in subtask_proportions:
+ subtask_proportions[name] = []
+ subtask_proportions[name].append(duration / total_duration)
+
+ if not subtask_proportions:
+ return {}
+
+ avg_proportions = {name: sum(props) / len(props) for name, props in subtask_proportions.items()}
+
+ total = sum(avg_proportions.values())
+ if total > 0:
+ avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
+
+ # Reorder according to subtask_order if provided
+ if subtask_order:
+ avg_proportions = {
+ name: avg_proportions.get(name, 0.0) for name in subtask_order if name in avg_proportions
+ }
+
+ return avg_proportions
+
+
+def create_sarm_prompt(subtask_list: list[str]) -> str:
+ subtask_str = "\n".join([f" - {name}" for name in subtask_list])
+
+ return textwrap.dedent(f"""\
+ # Role
+ You are a Robotics Vision System specializing in temporal action localization for robot manipulation. Your job is to segment a single demonstration video into distinct, non-overlapping atomic actions from a fixed subtask list.
+
+ # Subtask Label Set (Closed Vocabulary)
+ You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
+
+ [
+ {subtask_str}
+ ]
+
+ The video shows one successful execution of all subtasks in a logical order.
+
+ # Ground-Truth Semantics (Very Important)
+ Use **visual state changes** to define when a subtask starts and ends. Do NOT assume equal durations for the subtasks.
+
+ - A subtask **starts** at the first frame where the robot's motion clearly initiates that subtask.
+ - A subtask **ends** at the first frame where that specific action is visually completed and the manipulated object reaches a temporary, stable configuration.
+
+ If there are short pauses or micro-motions that don't clearly correspond to a new subtask, they belong to the **current** subtask.
+
+ # Hard Constraints & Logic
+ 1. **Continuous Coverage (No Gaps):**
+ - The entire video duration from "00:00" to the final timestamp must be covered by subtasks.
+ - There can be no gaps between subtasks.
+ - If there is any idle or ambiguous time between clear actions, extend the *preceding* subtask to cover it.
+
+ 2. **Boundary Consistency:**
+ - The `"end"` timestamp of one subtask must be exactly equal to the `"start"` timestamp of the next subtask.
+ - Boundaries must coincide with a real visual state transition, not just a convenient time split.
+
+ 3. **Chronological Order, One Occurrence Each:**
+ - This is a single successful demonstration.
+ - Each subtask from the vocabulary appears **exactly once**, in the correct logical order.
+ - **Durations may be very different** between subtasks. Never assume they are similar lengths. Base all boundaries only on the video.
+
+ 4. **Reject Uniform Segmentation (Important):**
+ - Do NOT simply divide the video into equal or nearly equal time chunks.
+ - If your boundaries would result in subtasks with similar durations (e.g. all around 5 seconds), treat this as evidence that your segmentation is wrong and refine the boundaries.
+ - Only use nearly equal durations if the video truly shows each subtask taking the same amount of time (this is very rare).
+
+ 5. **Timestamps:**
+ - Timestamps must be in `"MM:SS"` format.
+ - The first subtask always starts at `"00:00"`.
+ - The last subtask ends at the final visible frame of the video.
+
+ # Step 1 — Textual Timeline (must do this first)
+ First, write an extensive and detailed textual timeline describing what happens in the video with approximate timestamps.
+ For each subtask, include:
+ - its name
+ - an approximate start and end time,
+ - a description of the visual event at the boundary (e.g. "shirt fully folded to the left", "robot rotates folded shirt 90 degrees").
+
+ Format this as a bullet list.
+
+ # Step 2 — JSON Output (final answer)
+ After the textual timeline, output **only** valid JSON with this structure.
+ The JSON **must** be consistent with the textual timeline above:
+
+ {{
+ "subtasks": [
+ {{
+ "name": "EXACT_NAME_FROM_LIST",
+ "timestamps": {{
+ "start": "MM:SS",
+ "end": "MM:SS"
+ }}
+ }},
+ {{
+ "name": "EXACT_NAME_FROM_LIST",
+ "timestamps": {{
+ "start": "MM:SS",
+ "end": "MM:SS"
+ }}
+ }}
+ ]
+ }}
+
+ Do not add any extra keys to the JSON.
+ """)
+
+
+class VideoAnnotator:
+ """Annotates robot manipulation videos using a local Qwen3-VL model on GPU."""
+
+ def __init__(
+ self,
+ subtask_list: list[str],
+ model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
+ device: str = "cuda",
+ torch_dtype: torch.dtype = torch.bfloat16,
+ model: Qwen3VLMoeForConditionalGeneration | None = None, # noqa: F821
+ processor: AutoProcessor | None = None, # noqa: F821
+ ):
+ """
+ Initialize the video annotator with local model.
+
+ Args:
+ subtask_list: List of allowed subtask names (for consistency)
+ model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct)
+ device: Device to use (cuda, cpu)
+ torch_dtype: Data type for model (bfloat16, float16, float32)
+ model: Pre-loaded model instance (optional, to share between annotators)
+ processor: Pre-loaded processor instance (optional, to share between annotators)
+ """
+ self.subtask_list = subtask_list
+ self.prompt = create_sarm_prompt(subtask_list)
+ self.device = device
+
+ # Use provided model/processor or load new ones
+ if model is not None and processor is not None:
+ self.model = model
+ self.processor = processor
+ print(f"Using shared model on {device}")
+ else:
+ from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+ print(f"Loading model: {model_name}...")
+
+ self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+ model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
+ )
+
+ self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+
+ print(f"Model loaded successfully on {device}")
+
+ def extract_episode_segment(
+ self, file_path: Path, start_timestamp: float, end_timestamp: float, target_fps: int = 1
+ ) -> Path:
+ """
+ Extract a specific episode segment from a concatenated video.
+ Uses minimal compression to preserve quality for local inference.
+
+ Args:
+ file_path: Path to the concatenated video file
+ start_timestamp: Starting timestamp in seconds (within this video file)
+ end_timestamp: Ending timestamp in seconds (within this video file)
+ target_fps: Target FPS (default: 1 for faster processing)
+
+ Returns:
+ Path to extracted video file
+ """
+ # Create temporary file for extracted video
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
+ tmp_path = Path(tmp_file.name)
+
+ try:
+ # Check if ffmpeg is available
+ subprocess.run( # nosec B607
+ ["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
+ )
+ except (subprocess.CalledProcessError, FileNotFoundError) as err:
+ raise RuntimeError("ffmpeg not found, cannot extract episode segment") from err
+
+ try:
+ # Calculate duration
+ duration = end_timestamp - start_timestamp
+
+ print(f"Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)")
+
+ # Use ffmpeg to extract segment with minimal quality loss
+ cmd = [
+ "ffmpeg",
+ "-i",
+ str(file_path),
+ "-ss",
+ str(start_timestamp),
+ "-t",
+ str(duration),
+ "-r",
+ str(target_fps),
+ "-c:v",
+ "libx264",
+ "-preset",
+ "ultrafast",
+ "-crf",
+ "23",
+ "-an",
+ "-y",
+ str(tmp_path),
+ ]
+
+ subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
+
+ # Verify the output file was created and is not empty
+ if not tmp_path.exists() or tmp_path.stat().st_size == 0:
+ print("Video extraction failed (0 bytes) - skipping episode")
+ if tmp_path.exists():
+ tmp_path.unlink()
+ raise RuntimeError("FFmpeg produced empty video file")
+
+ # Show extraction results
+ file_size_mb = tmp_path.stat().st_size / (1024 * 1024)
+
+ # Fail if file is too small (< 100KB likely means extraction failed)
+ if file_size_mb < 0.1:
+ print(f"Extracted video too small ({file_size_mb:.2f}MB) - skipping episode")
+ tmp_path.unlink()
+ raise RuntimeError(f"Video extraction produced invalid file ({file_size_mb:.2f}MB)")
+
+ print(f"Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)")
+
+ return tmp_path
+
+ except subprocess.CalledProcessError as e:
+ raise RuntimeError(f"ffmpeg failed ({e})") from e
+
+ def annotate(
+ self,
+ file_path: str | Path,
+ fps: int,
+ start_timestamp: float = 0.0,
+ end_timestamp: float | None = None,
+ max_retries: int = 3,
+ ) -> SubtaskAnnotation:
+ """Annotate a video segment using local GPU."""
+ from qwen_vl_utils import process_vision_info
+
+ file_path = Path(file_path)
+
+ if end_timestamp is None:
+ cap = cv2.VideoCapture(str(file_path))
+ end_timestamp = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / (cap.get(cv2.CAP_PROP_FPS) or 1)
+ cap.release()
+
+ duration = end_timestamp - start_timestamp
+ duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
+
+ extracted_path = self.extract_episode_segment(file_path, start_timestamp, end_timestamp, 1)
+ is_extracted = extracted_path != file_path
+
+ try:
+ messages = [
+ {"role": "system", "content": [{"type": "text", "text": self.prompt}]},
+ {
+ "role": "user",
+ "content": [
+ {"type": "video", "video": str(extracted_path), "fps": 1.0},
+ {
+ "type": "text",
+ "text": f"Video is {duration_str} (~{duration:.1f}s). Follow instructions.",
+ },
+ ],
+ },
+ ]
+
+ for attempt in range(max_retries):
+ try:
+ text = self.processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ image_inputs, video_inputs = process_vision_info(messages)
+ inputs = self.processor(
+ text=[text],
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
+ return_tensors="pt",
+ ).to(self.device)
+
+ with torch.no_grad():
+ generated_ids = self.model.generate(
+ **inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
+ )
+
+ response = self.processor.batch_decode(
+ [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
+ skip_special_tokens=True,
+ )[0].strip()
+
+ # Extract JSON
+ if "```json" in response:
+ response = response.split("```json")[1].split("```")[0]
+ elif "```" in response:
+ response = response.split("```")[1].split("```")[0]
+
+ try:
+ return SubtaskAnnotation.model_validate(json.loads(response))
+ except json.JSONDecodeError:
+ match = re.search(r"\{.*\}", response, re.DOTALL)
+ if match:
+ return SubtaskAnnotation.model_validate(json.loads(match.group()))
+ raise ValueError("No JSON found") from None
+ except Exception as e:
+ if attempt == max_retries - 1:
+ raise RuntimeError(f"Failed after {max_retries} attempts") from e
+ time.sleep(1)
+ finally:
+ if is_extracted and extracted_path.exists():
+ extracted_path.unlink()
+
+
+def display_annotation(annotation: SubtaskAnnotation, episode_idx: int, fps: int, prefix: str = ""):
+ """Display annotation summary."""
+ subtask_summary = ", ".join(
+ f"{s.name}({s.timestamps.start}-{s.timestamps.end})" for s in annotation.subtasks
+ )
+ print(f"Episode {episode_idx} {prefix}: {len(annotation.subtasks)} subtasks - {subtask_summary}")
+
+
+def timestamp_to_seconds(timestamp: str) -> float:
+ """Convert MM:SS or SS timestamp to seconds"""
+ parts = timestamp.split(":")
+ if len(parts) == 2:
+ return int(parts[0]) * 60 + int(parts[1])
+ else:
+ return int(parts[0])
+
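+# e.g. timestamp_to_seconds("01:30") == 90 and timestamp_to_seconds("45") == 45.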
+
+def extract_frame(video_path: Path, timestamp: float) -> np.ndarray | None:
+ """Extract a single frame from video at given timestamp."""
+ cap = cv2.VideoCapture(str(video_path))
+ if not cap.isOpened():
+ return None
+ cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
+ ret, frame = cap.read()
+ cap.release()
+ return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if ret else None
+
+
+def draw_timeline(ax, subtasks, total_duration, colors):
+ """Draw a timeline with color-coded subtask segments."""
+ import matplotlib.patches as mpatches
+
+ bar_height, bar_y = 0.6, 0.5
+
+ for i, subtask in enumerate(subtasks):
+ start = timestamp_to_seconds(subtask.timestamps.start)
+ end = timestamp_to_seconds(subtask.timestamps.end)
+ color = colors[i % len(colors)]
+
+ rect = mpatches.FancyBboxPatch(
+ (start, bar_y - bar_height / 2),
+ end - start,
+ bar_height,
+ boxstyle="round,pad=0.02,rounding_size=0.1",
+ facecolor=color,
+ edgecolor="white",
+ linewidth=1.5,
+ alpha=0.85,
+ )
+ ax.add_patch(rect)
+
+ # Add label if segment is wide enough
+ duration = end - start
+ if duration > total_duration * 0.06:
+ ax.text(
+ (start + end) / 2,
+ bar_y,
+ subtask.name,
+ ha="center",
+ va="center",
+ fontsize=8,
+ fontweight="bold",
+ color="white",
+ rotation=0 if duration > total_duration * 0.12 else 45,
+ )
+
+ if i > 0:
+ ax.axvline(x=start, ymin=0.1, ymax=0.9, color="white", linestyle="--", linewidth=1.5, alpha=0.7)
+
+ ax.axvline(x=0, ymin=0.1, ymax=0.9, color="#00ff00", linestyle="-", linewidth=2, alpha=0.9)
+ if subtasks:
+ ax.axvline(
+ x=timestamp_to_seconds(subtasks[-1].timestamps.end),
+ ymin=0.1,
+ ymax=0.9,
+ color="white",
+ linestyle="--",
+ linewidth=1.5,
+ alpha=0.7,
+ )
+
+ ax.set_xlim(-total_duration * 0.02, total_duration * 1.02)
+ ax.set_ylim(-0.1, 1.1)
+ ax.set_xlabel("Time (seconds)", fontsize=10, color="white", labelpad=5)
+ for spine in ["top", "right", "left"]:
+ ax.spines[spine].set_visible(False)
+ ax.spines["bottom"].set_color("#444444")
+ ax.tick_params(axis="x", colors="#888888", labelsize=8)
+ ax.tick_params(axis="y", left=False, labelleft=False)
+
+
+def visualize_episode(
+ ep_idx: int,
+ annotation: SubtaskAnnotation,
+ video_path: Path,
+ video_start: float,
+ video_end: float,
+ output_path: Path,
+ video_key: str,
+ ann_type: str,
+):
+ """Create visualization for a single episode with frames and timeline."""
+ import matplotlib.pyplot as plt
+
+ if annotation is None:
+ print(f"No {ann_type} annotation for episode {ep_idx}")
+ return
+
+ subtasks = annotation.subtasks
+ if not subtasks:
+ print(f"No subtasks for episode {ep_idx}")
+ return
+
+ colors = plt.cm.tab10(np.linspace(0, 1, max(len(subtasks), 10)))
+ total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)
+
+ # Extract middle frame from each subtask
+ sample_frames, frame_times = [], []
+ for subtask in subtasks:
+ start = timestamp_to_seconds(subtask.timestamps.start)
+ end = timestamp_to_seconds(subtask.timestamps.end)
+ mid = (start + end) / 2
+ frame_times.append(mid)
+ sample_frames.append(extract_frame(video_path, video_start + mid))
+
+ # Create figure
+ fig_width = max(16, len(subtasks) * 2.5)
+ fig = plt.figure(figsize=(fig_width, 10))
+ fig.patch.set_facecolor("#1a1a2e")
+
+ gs = fig.add_gridspec(
+ 2,
+ max(len(subtasks), 1),
+ height_ratios=[2, 1],
+ hspace=0.3,
+ wspace=0.1,
+ left=0.05,
+ right=0.95,
+ top=0.88,
+ bottom=0.1,
+ )
+
+ fig.suptitle(
+ f"Episode {ep_idx} - {ann_type.capitalize()} Annotations",
+ fontsize=18,
+ fontweight="bold",
+ color="white",
+ y=0.96,
+ )
+ fig.text(
+ 0.5,
+ 0.91,
+ f"Camera: {video_key} | Duration: {video_end - video_start:.1f}s | {len(subtasks)} subtasks",
+ ha="center",
+ fontsize=11,
+ color="#888888",
+ )
+
+ # Plot frames
+ for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks, strict=True)):
+ ax = fig.add_subplot(gs[0, i])
+ ax.set_facecolor("#16213e")
+ if frame is not None:
+ ax.imshow(frame)
+ else:
+ ax.text(
+ 0.5, 0.5, "N/A", ha="center", va="center", fontsize=12, color="white", transform=ax.transAxes
+ )
+ ax.set_title(subtask.name, fontsize=10, fontweight="bold", color=colors[i % len(colors)], pad=8)
+ ax.axis("off")
+ ax.text(
+ 0.5,
+ -0.08,
+ f"t={frame_times[i]:.1f}s",
+ ha="center",
+ fontsize=9,
+ color="#888888",
+ transform=ax.transAxes,
+ )
+
+ # Plot timeline
+ ax_timeline = fig.add_subplot(gs[1, :])
+ ax_timeline.set_facecolor("#16213e")
+ draw_timeline(ax_timeline, subtasks, total_duration, colors)
+
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor="none", bbox_inches="tight")
+ plt.close()
+ print(f"Saved: {output_path}")
+
+
+def visualize_annotations(
+ dataset: LeRobotDataset,
+ sparse_annotations: dict[int, SubtaskAnnotation],
+ dense_annotations: dict[int, SubtaskAnnotation] | None,
+ video_key: str,
+ output_dir: Path,
+ num_episodes: int = 5,
+ annotation_type: str = "sparse",
+ episode_indices: list[int] | None = None,
+):
+ """
+ Visualize subtask annotations for a set of episodes.
+
+ Args:
+ dataset: LeRobotDataset instance
+ sparse_annotations: Dict mapping episode index to sparse annotations
+ dense_annotations: Dict mapping episode index to dense annotations (or None)
+ video_key: Camera/video key to use
+ output_dir: Directory to save visualization images
+ num_episodes: Number of episodes to visualize (ignored if episode_indices provided)
+ annotation_type: "sparse", "dense", or "both"
+ episode_indices: Specific episode indices to visualize (optional)
+ """
+ # Determine available episodes based on annotation type
+ if annotation_type == "sparse":
+ available = set(sparse_annotations.keys())
+ elif annotation_type == "dense":
+ available = set(dense_annotations.keys()) if dense_annotations else set()
+ else: # both
+ sparse_set = set(sparse_annotations.keys())
+ dense_set = set(dense_annotations.keys()) if dense_annotations else set()
+ available = sparse_set | dense_set
+
+ if not available:
+ print("Error: No annotations found to visualize.")
+ return
+
+ # Select episodes to visualize
+ if episode_indices:
+ episodes = sorted([e for e in episode_indices if e in available])
+ missing = set(episode_indices) - available
+ if missing:
+ print(f"Episodes not found in annotations: {sorted(missing)}")
+ else:
+ episodes = sorted(random.sample(list(available), min(num_episodes, len(available))))
+ print(f"Visualizing {len(episodes)} episodes: {episodes}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Generate visualizations
+ for i, ep_idx in enumerate(episodes, 1):
+ print(f"Processing episode {ep_idx} ({i}/{len(episodes)})")
+ video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
+ if not video_path.exists():
+ print(f"Video not found: {video_path}")
+ continue
+
+ video_start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
+ video_end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
+
+ if annotation_type == "both":
+ # Visualize both sparse and dense
+ for ann_type, annotations in [("sparse", sparse_annotations), ("dense", dense_annotations)]:
+ if annotations and ep_idx in annotations:
+ output_path = output_dir / f"episode_{ep_idx:04d}_{ann_type}.png"
+ visualize_episode(
+ ep_idx,
+ annotations.get(ep_idx),
+ video_path,
+ video_start,
+ video_end,
+ output_path,
+ video_key,
+ ann_type,
+ )
+ else:
+ annotations = sparse_annotations if annotation_type == "sparse" else dense_annotations
+ if annotations and ep_idx in annotations:
+ output_path = output_dir / f"episode_{ep_idx:04d}_{annotation_type}.png"
+ visualize_episode(
+ ep_idx,
+ annotations.get(ep_idx),
+ video_path,
+ video_start,
+ video_end,
+ output_path,
+ video_key,
+ annotation_type,
+ )
+
+ print(f"Visualizations saved to: {output_dir.absolute()}")
+
+
+def save_annotations_to_dataset(
+ dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
+):
+ """Save annotations to LeRobot dataset parquet format."""
+ from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
+
+ episodes_dataset = load_episodes(dataset_path)
+ if not episodes_dataset or len(episodes_dataset) == 0:
+ return
+
+ episodes_df = episodes_dataset.to_pandas()
+ cols = [
+ f"{prefix}_{c}"
+ for c in [
+ "subtask_names",
+ "subtask_start_times",
+ "subtask_end_times",
+ "subtask_start_frames",
+ "subtask_end_frames",
+ ]
+ ]
+ for col in cols:
+ episodes_df[col] = None
+
+ for ep_idx, ann in annotations.items():
+ if ep_idx >= len(episodes_df):
+ continue
+ names, starts, ends, start_frames, end_frames = [], [], [], [], []
+ for s in ann.subtasks:
+ names.append(s.name)
+ st, et = timestamp_to_seconds(s.timestamps.start), timestamp_to_seconds(s.timestamps.end)
+ starts.append(st)
+ ends.append(et)
+ start_frames.append(int(st * fps))
+ end_frames.append(int(et * fps))
+ episodes_df.at[ep_idx, cols[0]] = names
+ episodes_df.at[ep_idx, cols[1]] = starts
+ episodes_df.at[ep_idx, cols[2]] = ends
+ episodes_df.at[ep_idx, cols[3]] = start_frames
+ episodes_df.at[ep_idx, cols[4]] = end_frames
+
+ # Group by file and write
+ for ep_idx in episodes_df.index:
+ key = (
+ episodes_df.loc[ep_idx, "meta/episodes/chunk_index"],
+ episodes_df.loc[ep_idx, "meta/episodes/file_index"],
+ )
+ path = dataset_path / DEFAULT_EPISODES_PATH.format(chunk_index=key[0], file_index=key[1])
+ if path.exists():
+ file_df = pd.read_parquet(path)
+ for col in cols + (
+ [
+ "subtask_names",
+ "subtask_start_times",
+ "subtask_end_times",
+ "subtask_start_frames",
+ "subtask_end_frames",
+ ]
+ if prefix == "sparse"
+ else []
+ ):
+ if col not in file_df.columns:
+ file_df[col] = None
+ if ep_idx in annotations:
+ for col in cols:
+ file_df.at[ep_idx, col] = episodes_df.loc[ep_idx, col]
+ if prefix == "sparse": # Legacy columns
+ for i, legacy in enumerate(
+ [
+ "subtask_names",
+ "subtask_start_times",
+ "subtask_end_times",
+ "subtask_start_frames",
+ "subtask_end_frames",
+ ]
+ ):
+ file_df.at[ep_idx, legacy] = episodes_df.loc[ep_idx, cols[i]]
+ file_df.to_parquet(path, engine="pyarrow", compression="snappy")
+
+
+def generate_auto_sparse_annotations(
+ dataset: LeRobotDataset, episode_indices: list[int], video_key: str
+) -> dict[int, SubtaskAnnotation]:
+ """Auto-generate single 'task' stage annotations for all episodes."""
+ annotations = {}
+ for ep_idx in episode_indices:
+ start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
+ end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
+ duration = end - start
+ end_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
+ annotations[ep_idx] = SubtaskAnnotation(
+ subtasks=[Subtask(name="task", timestamps=Timestamp(start="00:00", end=end_str))]
+ )
+ return annotations
+
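+# e.g. a 95-second episode yields end_str == "01:35", i.e. a single "task" stage
+# covering "00:00"-"01:35" (the fallback used when no sparse subtask list is given).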
+
+def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
+ """Load annotations from LeRobot dataset parquet files."""
+ from lerobot.datasets.utils import load_episodes
+
+ episodes_dataset = load_episodes(dataset_path)
+ if not episodes_dataset or len(episodes_dataset) == 0:
+ return {}
+
+ col_names = f"{prefix}_subtask_names"
+ col_start = f"{prefix}_subtask_start_times"
+ col_end = f"{prefix}_subtask_end_times"
+
+ # Fall back to legacy columns for sparse
+ if col_names not in episodes_dataset.column_names:
+ if prefix == "sparse" and "subtask_names" in episodes_dataset.column_names:
+ col_names, col_start, col_end = "subtask_names", "subtask_start_times", "subtask_end_times"
+ else:
+ return {}
+
+ df = episodes_dataset.to_pandas()
+ annotations = {}
+ for ep_idx in df.index:
+ names = df.loc[ep_idx, col_names]
+ if names is None or (isinstance(names, float) and pd.isna(names)):
+ continue
+ starts, ends = df.loc[ep_idx, col_start], df.loc[ep_idx, col_end]
+ annotations[int(ep_idx)] = SubtaskAnnotation(
+ subtasks=[
+ Subtask(
+ name=n,
+ timestamps=Timestamp(
+ start=f"{int(s) // 60:02d}:{int(s) % 60:02d}",
+ end=f"{int(e) // 60:02d}:{int(e) % 60:02d}",
+ ),
+ )
+ for n, s, e in zip(names, starts, ends, strict=True)
+ ]
+ )
+ return annotations
+
+
+def process_single_episode(
+ ep_idx: int,
+ dataset_root: Path,
+ dataset_meta,
+ video_key: str,
+ fps: int,
+ annotator: VideoAnnotator,
+) -> tuple[int, SubtaskAnnotation | None, str | None]:
+ """Process a single episode annotation."""
+ try:
+ video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key)
+ if not video_path.exists():
+ return ep_idx, None, f"Video not found: {video_path}"
+
+ start = float(dataset_meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
+ end = float(dataset_meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
+ return ep_idx, annotator.annotate(video_path, fps, start, end), None
+ except Exception as e:
+ return ep_idx, None, str(e)
+
+
+def worker_process_episodes(
+ worker_id: int,
+ gpu_id: int,
+ episode_indices: list[int],
+ repo_id: str,
+ video_key: str,
+ sparse_subtask_list: list[str],
+ dense_subtask_list: list[str] | None,
+ model_name: str,
+ torch_dtype: torch.dtype,
+) -> tuple[dict, dict | None]:
+ """Worker for parallel processing across GPUs."""
+ device = f"cuda:{gpu_id}"
+ dataset = LeRobotDataset(repo_id, download_videos=False)
+
+ sparse_annotator = VideoAnnotator(sparse_subtask_list, model_name, device, torch_dtype)
+ dense_annotator = (
+ VideoAnnotator(
+ dense_subtask_list,
+ model_name,
+ device,
+ torch_dtype,
+ sparse_annotator.model,
+ sparse_annotator.processor,
+ )
+ if dense_subtask_list
+ else None
+ )
+
+ sparse_annotations, dense_annotations = {}, {} if dense_subtask_list else None
+
+ for ep_idx in episode_indices:
+ _, sparse_ann, err = process_single_episode(
+ ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, sparse_annotator
+ )
+ if sparse_ann:
+ sparse_annotations[ep_idx] = sparse_ann
+
+ if dense_annotator:
+ _, dense_ann, _ = process_single_episode(
+ ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, dense_annotator
+ )
+ if dense_ann:
+ dense_annotations[ep_idx] = dense_ann
+
+ return sparse_annotations, dense_annotations
+
+
+def main():
+ parser = argparse.ArgumentParser(description="SARM-style subtask annotation using a local GPU (Qwen3-VL)")
+ parser.add_argument("--repo-id", type=str, required=True, help="HuggingFace dataset repository ID")
+ parser.add_argument(
+ "--sparse-subtasks", type=str, default=None, help="Comma-separated sparse subtask names"
+ )
+ parser.add_argument(
+ "--dense-subtasks", type=str, default=None, help="Comma-separated dense subtask names"
+ )
+ parser.add_argument(
+ "--dense-only", action="store_true", help="Dense-only mode with auto-generated sparse 'task' stage"
+ )
+ parser.add_argument("--episodes", type=int, nargs="+", default=None, help="Episode indices to annotate")
+ parser.add_argument("--model", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="VLM model")
+ parser.add_argument("--skip-existing", action="store_true", help="Skip already annotated episodes")
+ parser.add_argument("--video-key", type=str, default=None, help="Video key (default: first available)")
+ parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
+ parser.add_argument("--output-repo-id", type=str, default=None, help="Output repo ID for push")
+ parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
+ parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
+ parser.add_argument("--num-workers", type=int, default=1, help="Parallel workers for multi-GPU")
+ parser.add_argument("--gpu-ids", type=int, nargs="+", default=None, help="GPU IDs to use")
+ # Visualization options
+ parser.add_argument(
+ "--visualize-only",
+ action="store_true",
+ help="Only visualize existing annotations (no generation)",
+ )
+ parser.add_argument(
+ "--num-visualizations",
+ type=int,
+ default=5,
+ help="Number of episodes to visualize (default: 5)",
+ )
+ parser.add_argument(
+ "--visualize-type",
+ type=str,
+ default="sparse",
+ choices=["sparse", "dense", "both"],
+ help="Type of annotations to visualize (default: sparse)",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="./subtask_viz",
+ help="Output directory for visualizations (default: ./subtask_viz)",
+ )
+
+ args = parser.parse_args()
+
+ # Load dataset first (needed for both annotation and visualization)
+ print(f"Loading dataset: {args.repo_id}")
+ dataset = LeRobotDataset(args.repo_id, download_videos=True)
+ fps = dataset.fps
+
+ if not dataset.meta.video_keys:
+ raise ValueError("No video keys found")
+
+ video_key = (
+ args.video_key if args.video_key in (dataset.meta.video_keys or []) else dataset.meta.video_keys[0]
+ )
+ print(f"Using camera: {video_key}, FPS: {fps}")
+
+ # Handle visualization-only mode
+ if args.visualize_only:
+ print("Visualization-only mode")
+ sparse_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse")
+ dense_annotations = load_annotations_from_dataset(dataset.root, prefix="dense")
+
+ if not sparse_annotations and not dense_annotations:
+ return print("Error: No annotations found. Run annotation first.")
+
+ print(f"Found {len(sparse_annotations)} sparse, {len(dense_annotations)} dense annotations")
+
+ visualize_annotations(
+ dataset=dataset,
+ sparse_annotations=sparse_annotations,
+ dense_annotations=dense_annotations if dense_annotations else None,
+ video_key=video_key,
+ output_dir=Path(args.output_dir),
+ num_episodes=args.num_visualizations,
+ annotation_type=args.visualize_type,
+ episode_indices=args.episodes,
+ )
+ return
+
+ # Validate arguments for annotation mode
+ if args.dense_only and not args.dense_subtasks:
+ return print("Error: --dense-only requires --dense-subtasks")
+ if args.dense_subtasks and not args.sparse_subtasks and not args.dense_only:
+ return print("Error: --dense-subtasks requires --sparse-subtasks or --dense-only")
+
+ sparse_subtask_list = (
+ [s.strip() for s in args.sparse_subtasks.split(",")] if args.sparse_subtasks else None
+ )
+ dense_subtask_list = [s.strip() for s in args.dense_subtasks.split(",")] if args.dense_subtasks else None
+ auto_sparse = sparse_subtask_list is None
+ dense_mode = dense_subtask_list is not None
+ torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
+
+ # Determine episodes
+ episode_indices = args.episodes or list(range(dataset.meta.total_episodes))
+
+ existing_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse")
+ if args.skip_existing:
+ episode_indices = [ep for ep in episode_indices if ep not in existing_annotations]
+
+ if not episode_indices:
+ return print("All episodes already annotated!")
+ print(f"Annotating {len(episode_indices)} episodes")
+
+ # GPU setup
+ gpu_ids = args.gpu_ids or list(
+ range(min(args.num_workers, torch.cuda.device_count() if torch.cuda.is_available() else 1))
+ )
+ args.num_workers = len(gpu_ids)
+
+ sparse_annotations = existing_annotations.copy()
+ dense_annotations = {} if dense_mode else None
+
+ # Auto-sparse mode
+ if auto_sparse:
+ sparse_annotations.update(generate_auto_sparse_annotations(dataset, episode_indices, video_key))
+ save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
+ print(f"Auto-generated {len(episode_indices)} sparse 'task' annotations")
+
+ # VLM annotation (for sparse if not auto, and for dense)
+ need_vlm = (not auto_sparse) or dense_mode
+
+ if need_vlm:
+ if args.num_workers > 1 and not auto_sparse:
+ # Parallel processing
+ print(f"Parallel processing with {args.num_workers} workers")
+ episodes_per_worker = [[] for _ in range(args.num_workers)]
+ for i, ep_idx in enumerate(episode_indices):
+ episodes_per_worker[i % args.num_workers].append(ep_idx)
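+ # Round-robin split: e.g. 5 episodes on 2 workers -> worker 0 gets the
+ # episodes at positions [0, 2, 4], worker 1 gets those at [1, 3].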
+
+ with ProcessPoolExecutor(
+ max_workers=args.num_workers, mp_context=mp.get_context("spawn")
+ ) as executor:
+ futures = [
+ executor.submit(
+ worker_process_episodes,
+ w,
+ gpu_ids[w],
+ episodes_per_worker[w],
+ args.repo_id,
+ video_key,
+ sparse_subtask_list,
+ dense_subtask_list,
+ args.model,
+ torch_dtype,
+ )
+ for w in range(args.num_workers)
+ if episodes_per_worker[w]
+ ]
+
+ for future in as_completed(futures):
+ try:
+ worker_sparse, worker_dense = future.result()
+ sparse_annotations.update(worker_sparse)
+ if dense_mode and worker_dense:
+ dense_annotations.update(worker_dense)
+ save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
+ if dense_mode:
+ save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
+ except Exception as e:
+ raise RuntimeError(f"Worker failed: {e}") from e
+ else:
+ # Sequential processing
+ sparse_annotator = (
+ VideoAnnotator(sparse_subtask_list, args.model, args.device, torch_dtype)
+ if not auto_sparse and sparse_subtask_list
+ else None
+ )
+ dense_annotator = (
+ VideoAnnotator(
+ dense_subtask_list,
+ args.model,
+ args.device,
+ torch_dtype,
+ sparse_annotator.model if sparse_annotator else None,
+ sparse_annotator.processor if sparse_annotator else None,
+ )
+ if dense_mode
+ else None
+ )
+
+ for i, ep_idx in enumerate(episode_indices):
+ print(f"Episode {ep_idx} ({i + 1}/{len(episode_indices)})")
+
+ if sparse_annotator:
+ _, sparse_ann, err = process_single_episode(
+ ep_idx, dataset.root, dataset.meta, video_key, fps, sparse_annotator
+ )
+ if sparse_ann:
+ sparse_annotations[ep_idx] = sparse_ann
+ save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
+ elif err:
+ print(f"Sparse failed: {err}")
+
+ if dense_annotator:
+ _, dense_ann, err = process_single_episode(
+ ep_idx, dataset.root, dataset.meta, video_key, fps, dense_annotator
+ )
+ if dense_ann:
+ dense_annotations[ep_idx] = dense_ann
+ save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
+ elif err:
+ print(f"Dense failed: {err}")
+
+ # Save temporal proportions
+ def save_proportions(annotations, prefix, subtask_list=None, is_auto=False):
+ props: dict[str, float] = (
+ {"task": 1.0} if is_auto else compute_temporal_proportions(annotations, fps, subtask_list)
+ )
+ path = dataset.root / "meta" / f"temporal_proportions_{prefix}.json"
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "w") as f:
+ json.dump(props, f, indent=2)
+ print(f"Saved {prefix} temporal proportions")
+
+ save_proportions(sparse_annotations, "sparse", sparse_subtask_list, auto_sparse)
+ if dense_mode and dense_annotations:
+ save_proportions(dense_annotations, "dense", dense_subtask_list)
+
+ print(f"\nComplete! {len(sparse_annotations)} sparse, {len(dense_annotations or {})} dense annotations")
+
+ # Visualize annotations after generation
+ if args.num_visualizations > 0:
+ print(f"\nGenerating {args.num_visualizations} visualizations...")
+ visualize_type = "both" if dense_mode else "sparse"
+ visualize_annotations(
+ dataset=dataset,
+ sparse_annotations=sparse_annotations,
+ dense_annotations=dense_annotations,
+ video_key=video_key,
+ output_dir=Path(args.output_dir),
+ num_episodes=args.num_visualizations,
+ annotation_type=visualize_type,
+ )
+
+ if args.push_to_hub:
+ try:
+ dataset.push_to_hub(push_videos=True)
+ print(f"Pushed to {args.output_repo_id or args.repo_id}")
+ except Exception as e:
+ print(f"Push failed: {e}")
+
+
+if __name__ == "__main__":
+ main()
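The `save_proportions` helper above writes per-subtask temporal proportions to `meta/temporal_proportions_<prefix>.json`. A toy sketch of the kind of file it produces, assuming proportions are episode-averaged subtask duration shares (the subtask names and durations are hypothetical; `compute_temporal_proportions` is the actual implementation):

```python
import json

# Hypothetical per-episode subtask durations in frames for three annotated episodes.
episode_durations = [
    {"grasp": 40, "lift": 20, "place": 40},
    {"grasp": 50, "lift": 30, "place": 20},
    {"grasp": 30, "lift": 30, "place": 40},
]

# Average each subtask's share of its own episode across the dataset,
# then renormalize so the proportions sum to exactly 1.
subtasks = list(episode_durations[0])
shares = {
    k: sum(ep[k] / sum(ep.values()) for ep in episode_durations) / len(episode_durations)
    for k in subtasks
}
total = sum(shares.values())
proportions = {k: round(v / total, 4) for k, v in shares.items()}

print(json.dumps(proportions, indent=2))
```

In auto-sparse mode the script skips this computation and writes `{"task": 1.0}` instead.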
diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py
index 788542d49..ceefb0d56 100644
--- a/src/lerobot/policies/__init__.py
+++ b/src/lerobot/policies/__init__.py
@@ -29,6 +29,7 @@ __all__ = [
"PI0Config",
"PI05Config",
"SmolVLAConfig",
+ "SARMConfig",
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py
index b7cbcd061..a5c48eb3d 100644
--- a/src/lerobot/policies/act/modeling_act.py
+++ b/src/lerobot/policies/act/modeling_act.py
@@ -50,6 +50,7 @@ class ACTPolicy(PreTrainedPolicy):
def __init__(
self,
config: ACTConfig,
+ **kwargs,
):
"""
Args:
diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py
index 3ab6719cb..1fdc76f10 100644
--- a/src/lerobot/policies/diffusion/modeling_diffusion.py
+++ b/src/lerobot/policies/diffusion/modeling_diffusion.py
@@ -56,6 +56,7 @@ class DiffusionPolicy(PreTrainedPolicy):
def __init__(
self,
config: DiffusionConfig,
+ **kwargs,
):
"""
Args:
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index 3d17fa7dc..eb1ff41f7 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -37,6 +37,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
@@ -105,6 +106,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
+ elif name == "sarm":
+ from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
+
+ return SARMRewardModel
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
@@ -337,6 +342,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
+ elif isinstance(policy_cfg, SARMConfig):
+ from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
+
+ processors = make_sarm_pre_post_processors(
+ config=policy_cfg,
+ dataset_stats=kwargs.get("dataset_stats"),
+ dataset_meta=kwargs.get("dataset_meta"),
+ )
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
@@ -435,6 +448,13 @@ def make_policy(
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
+ # Pass dataset_stats to the policy if available (needed for some policies like SARM)
+ if ds_meta is not None and hasattr(ds_meta, "stats"):
+ kwargs["dataset_stats"] = ds_meta.stats
+
+ if ds_meta is not None:
+ kwargs["dataset_meta"] = ds_meta
+
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py
index 605f7a097..bdaef37b9 100644
--- a/src/lerobot/policies/groot/modeling_groot.py
+++ b/src/lerobot/policies/groot/modeling_groot.py
@@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy):
name = "groot"
config_class = GrootConfig
- def __init__(self, config: GrootConfig):
+ def __init__(self, config: GrootConfig, **kwargs):
"""Initialize Groot policy wrapper."""
super().__init__(config)
config.validate_features()
diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index 4b79c2902..0d9c77e00 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -907,6 +907,7 @@ class PI0Policy(PreTrainedPolicy):
def __init__(
self,
config: PI0Config,
+ **kwargs,
):
"""
Args:
@@ -1235,9 +1236,15 @@ class PI0Policy(PreTrainedPolicy):
return actions
- def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
- """Run the batch through the model and compute the loss for training."""
+ def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
+        """Run the batch through the model and compute the loss for training.
+
+        Args:
+ batch: Training batch containing observations and actions.
+ reduction: How to reduce the loss. Options:
+ - "mean": Return scalar mean loss (default, backward compatible)
+ - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+ """
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
@@ -1251,11 +1258,17 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[:, :, :original_action_dim]
- loss = losses.mean()
-
loss_dict = {
- "loss": loss.item(),
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
}
- return loss, loss_dict
+ if reduction == "none":
+ # Return per-sample losses (B,) by averaging over time and action dims
+ per_sample_loss = losses.mean(dim=(1, 2))
+ loss_dict["loss"] = per_sample_loss.mean().item()
+ return per_sample_loss, loss_dict
+ else:
+ # Default: return scalar mean loss
+ loss = losses.mean()
+ loss_dict["loss"] = loss.item()
+ return loss, loss_dict
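The `reduction="none"` path returns one loss per sample so RA-BC can reweight samples by SARM progress. A minimal sketch of such a weighted update, with an illustrative weighting rule (the exact RA-BC scheme is not shown here; `progress_delta` is assumed to come from SARM progress predictions):

```python
import torch

def rabc_weighted_loss(per_sample_loss: torch.Tensor, progress_delta: torch.Tensor) -> torch.Tensor:
    """Reweight per-sample BC losses by predicted progress gain (illustrative rule).

    per_sample_loss: (B,) losses from policy.forward(batch, reduction="none")
    progress_delta: (B,) SARM progress change over each sample's action chunk
    """
    # Down-weight samples whose progress regresses; keep the mean weight at ~1.
    weights = torch.clamp(progress_delta, min=0.0)
    weights = weights * len(weights) / (weights.sum() + 1e-8)
    return (weights * per_sample_loss).mean()

per_sample_loss = torch.tensor([1.0, 2.0, 3.0])
progress_delta = torch.tensor([0.5, 0.5, -0.2])  # third sample moves backwards
loss = rabc_weighted_loss(per_sample_loss, progress_delta)
print(loss.item())  # scalar, so loss.backward() works as in plain BC
```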
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index 64eb4cb23..2cd142042 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -880,6 +880,7 @@ class PI05Policy(PreTrainedPolicy):
def __init__(
self,
config: PI05Config,
+ **kwargs,
):
"""
Args:
@@ -1209,9 +1210,15 @@ class PI05Policy(PreTrainedPolicy):
return actions
- def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
- """Run the batch through the model and compute the loss for training."""
+ def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
+        """Run the batch through the model and compute the loss for training.
+
+        Args:
+ batch: Training batch containing observations and actions.
+ reduction: How to reduce the loss. Options:
+ - "mean": Return scalar mean loss (default, backward compatible)
+ - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+ """
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
@@ -1225,11 +1232,17 @@ class PI05Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[:, :, :original_action_dim]
- loss = losses.mean()
-
loss_dict = {
- "loss": loss.item(),
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
}
- return loss, loss_dict
+ if reduction == "none":
+ # Return per-sample losses (B,) by averaging over time and action dims
+ per_sample_loss = losses.mean(dim=(1, 2))
+ loss_dict["loss"] = per_sample_loss.mean().item()
+ return per_sample_loss, loss_dict
+ else:
+ # Default: return scalar mean loss
+ loss = losses.mean()
+ loss_dict["loss"] = loss.item()
+ return loss, loss_dict
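`compute_rabc_weights.py` (added below) can score only every N-th frame with `--stride N` and fill the gaps by linear interpolation. The fill step reduces to `np.interp`, as in this standalone toy example:

```python
import numpy as np

# Progress scored only at strided frames 0, 5 and 10 of an 11-frame episode.
computed_indices = np.array([0, 5, 10])
computed_values = np.array([0.0, 0.4, 1.0])
all_indices = np.arange(11)

# Linear interpolation fills the unscored frames (what interpolate_progress does
# for finite inputs); endpoints clamp to the nearest computed value.
filled = np.interp(all_indices, computed_indices, computed_values).astype(np.float32)
print(filled[2], filled[7])
```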
diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py
new file mode 100644
index 000000000..5b6ea6e9b
--- /dev/null
+++ b/src/lerobot/policies/sarm/compute_rabc_weights.py
@@ -0,0 +1,870 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Compute SARM progress values for RA-BC (Reward-Aligned Behavior Cloning) weighting.
+
+This script runs SARM over all frames in a dataset to compute progress values in [0, 1].
+The results are saved as a parquet file that can be loaded during training for RA-BC weighting.
+
+Each query scores the center frame of SARM's bidirectional observation window. Pass --stride N
+to score only every N-th frame and linearly interpolate the rest (~N-x speedup).
+
+Usage:
+ # Full RA-BC computation with visualizations
+ python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
+ --reward-model-path pepijn223/sarm_single_uni4
+
+ # Faster computation with stride (compute every 5 frames, interpolate the rest)
+ python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
+ --reward-model-path pepijn223/sarm_single_uni4 \\
+ --stride 5
+
+ # Visualize predictions only (no RA-BC computation)
+ python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
+ --reward-model-path pepijn223/sarm_single_uni4 \\
+ --visualize-only \\
+ --num-visualizations 5
+
+The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+import numpy as np
+import pyarrow as pa
+import pyarrow.parquet as pq
+import torch
+from tqdm import tqdm
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
+from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
+from lerobot.policies.sarm.sarm_utils import normalize_stage_tau
+
+
+def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
+ """Read reward_model_path from parquet metadata if available."""
+ if not parquet_path.exists():
+ return None
+ try:
+ metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
+ if metadata and b"reward_model_path" in metadata:
+ return metadata[b"reward_model_path"].decode()
+ except Exception: # nosec B110
+ return None
+ return None
+
+
+def load_sarm_resources(
+ dataset_repo_id: str,
+ reward_model_path: str,
+ device: str = "cuda",
+) -> tuple[LeRobotDataset, SARMRewardModel, any]:
+ """
+ Load SARM model, dataset, and preprocessor.
+
+ Returns:
+ Tuple of (dataset, reward_model, preprocessor)
+ """
+ logging.info(f"Loading model: {reward_model_path}")
+ reward_model = SARMRewardModel.from_pretrained(reward_model_path)
+ reward_model.config.device = device
+ reward_model.to(device).eval()
+
+ image_key = reward_model.config.image_key
+ state_key = reward_model.config.state_key
+ delta_indices = reward_model.config.observation_delta_indices
+
+ logging.info(f"Loading dataset: {dataset_repo_id}")
+ temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
+ fps = temp_dataset.fps
+
+ delta_timestamps = {
+ image_key: [idx / fps for idx in delta_indices],
+ state_key: [idx / fps for idx in delta_indices],
+ }
+ dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps)
+ logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
+
+ preprocess, _ = make_sarm_pre_post_processors(
+ config=reward_model.config,
+ dataset_stats=dataset.meta.stats,
+ dataset_meta=dataset.meta,
+ )
+
+ return dataset, reward_model, preprocess
+
+
+def to_numpy_image(img) -> np.ndarray:
+ """Convert image tensor to numpy uint8 (H, W, C)."""
+ if isinstance(img, torch.Tensor):
+ img = img.cpu().numpy()
+ if img.ndim == 4:
+ # Take center frame for bidirectional sampling
+ img = img[img.shape[0] // 2]
+ if img.shape[0] in [1, 3]:
+ img = np.transpose(img, (1, 2, 0))
+ if img.dtype != np.uint8:
+ # Handle normalized images (may have negative values or values > 1)
+ img = img.astype(np.float32)
+ img = (img - img.min()) / (img.max() - img.min() + 1e-8) # Normalize to [0, 1]
+ img = (img * 255).astype(np.uint8)
+ return img
+
+
+def visualize_episode(
+ frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None
+):
+ """Create visualization with progress plot, stage probabilities, and sample frames.
+
+    Same layout as sarm_inference_visualization.py.
+ """
+ num_stages = stage_preds.shape[1]
+ colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
+ frame_indices = np.arange(len(progress_preds))
+
+ fig = plt.figure(figsize=(14, 12))
+ gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
+ ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2])
+
+ # Progress plot
+ ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted")
+ ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB")
+ if gt_progress is not None:
+ ax_progress.plot(
+ frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth"
+ )
+ ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)
+ ax_progress.set_ylabel("Progress")
+ ax_progress.set_title(f'Task: "{title}"', fontweight="bold")
+ ax_progress.set_ylim(-0.05, 1.1)
+ ax_progress.legend(loc="upper left")
+ ax_progress.grid(True, alpha=0.3)
+
+ # Stage predictions
+ ax_stages.stackplot(
+ frame_indices,
+ *[stage_preds[:, i] for i in range(num_stages)],
+ colors=colors,
+ alpha=0.8,
+ labels=stage_labels,
+ )
+ if gt_stages is not None:
+ for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1:
+ ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5)
+ ax_stages.set_xlabel("Frame")
+ ax_stages.set_ylabel("Stage Probability")
+ ax_stages.set_ylim(0, 1)
+ ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8)
+ ax_stages.grid(True, alpha=0.3)
+
+ # Sample frames
+ ax_frames.axis("off")
+ num_sample = 8
+ sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int)
+ h, w = frames[0].shape[:2]
+ combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8)
+ for i, idx in enumerate(sample_indices):
+ frame = frames[idx]
+ if frame.shape[-1] == 1:
+ frame = np.repeat(frame, 3, axis=-1)
+ combined[:, i * w : (i + 1) * w] = frame
+ stage_name = stage_labels[np.argmax(stage_preds[idx])][:12]
+ ax_frames.text(
+ i * w + w / 2,
+ -10,
+ f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}",
+ ha="center",
+ va="top",
+ fontsize=7,
+ )
+ ax_frames.imshow(combined)
+ ax_frames.set_title("Sample Frames", pad=20)
+
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ plt.close()
+ print(f"Saved: {output_path}")
+
+
+def visualize_sarm_predictions(
+ dataset: LeRobotDataset,
+ reward_model: SARMRewardModel,
+ preprocess,
+ episode_indices: list[int],
+ head_mode: str,
+ output_dir: Path,
+ num_display_frames: int = 5,
+ stride: int = 1,
+):
+ """
+ Visualize SARM predictions for multiple episodes.
+
+ Computes predictions for every frame by default. With stride > 1, computes predictions
+ every N frames and interpolates (progress + stage probabilities) for visualization.
+
+ Args:
+ dataset: LeRobotDataset with delta_timestamps configured
+ reward_model: Loaded SARM model
+ preprocess: Preprocessor from make_sarm_pre_post_processors
+ episode_indices: List of episode indices to visualize
+ head_mode: "sparse", "dense", or "both"
+ output_dir: Directory to save visualizations
+ num_display_frames: Number of frames to display in thumbnail strip (default: 5)
+ stride: Compute predictions every N frames, interpolate the rest (default: 1)
+ """
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ image_key = reward_model.config.image_key
+ state_key = reward_model.config.state_key
+ dual_mode = reward_model.config.uses_dual_heads
+ device = reward_model.device
+
+ # Center frame index for bidirectional sampling
+ target_idx = reward_model.config.n_obs_steps // 2
+
+ # Determine which heads to visualize
+ schemes_to_viz = []
+ if head_mode in ("sparse", "both") or not dual_mode:
+ schemes_to_viz.append("sparse")
+ if head_mode in ("dense", "both") and dual_mode:
+ schemes_to_viz.append("dense")
+
+ # Set preprocessor to eval mode to disable augmentations
+ if hasattr(preprocess, "eval"):
+ preprocess.eval()
+ for step in preprocess.steps:
+ if hasattr(step, "eval"):
+ step.eval()
+
+ for episode_idx in episode_indices:
+ ep = dataset.meta.episodes[episode_idx]
+ ep_start = ep["dataset_from_index"]
+ ep_end = ep["dataset_to_index"]
+ task = dataset[ep_start].get("task", "perform the task")
+ num_frames = ep_end - ep_start
+
+ # Select frames for display thumbnails (evenly sampled from begin to end)
+ display_indices = set(
+ [
+ ep_start + int(i * (num_frames - 1) / (num_display_frames - 1))
+ for i in range(num_display_frames)
+ ]
+ if num_frames >= num_display_frames
+ else list(range(ep_start, ep_end))
+ )
+ viz_frames = {}
+
+ # Load display frames up-front (stride mode might skip them otherwise).
+ for frame_idx in display_indices:
+ sample = dataset[frame_idx]
+ viz_frames[frame_idx] = to_numpy_image(sample[image_key])
+
+ # Initialize storage for each scheme
+ scheme_data = {}
+ for scheme in schemes_to_viz:
+ num_stages = getattr(reward_model.config, f"num_{scheme}_stages")
+ scheme_data[scheme] = {
+ "viz_progress": np.full(num_frames, np.nan),
+ "viz_stages": np.full((num_frames, num_stages), np.nan),
+ "viz_gt_progress": np.full(num_frames, np.nan),
+ "viz_gt_stages": np.full(num_frames, np.nan),
+ "target_key": f"{scheme}_targets",
+ "num_stages": num_stages,
+ "temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"),
+ "subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"),
+ }
+
+ if stride > 1:
+            logging.info(f"Visualization stride={stride}: computing predictions every {stride} frames, interpolating the rest")
+
+ # Process frames one at a time to avoid memory buildup
+ frame_indices = list(range(ep_start, ep_end, stride))
+ if (ep_end - 1) not in frame_indices:
+ frame_indices.append(ep_end - 1)
+ frame_indices = sorted(set(frame_indices))
+
+ for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", leave=False):
+ local_idx = frame_idx - ep_start
+ sample = dataset[frame_idx]
+
+ batch = {
+ image_key: sample[image_key],
+ "task": task,
+ "index": frame_idx,
+ "episode_index": episode_idx,
+ }
+ if state_key in sample:
+ batch[state_key] = sample[state_key]
+
+ with torch.no_grad():
+ processed = preprocess(batch)
+ video_features = processed["video_features"].to(device)
+ text_features = processed["text_features"].to(device)
+ state_features = processed.get("state_features")
+ if state_features is not None:
+ state_features = state_features.to(device)
+ lengths = processed.get("lengths")
+
+ for scheme in schemes_to_viz:
+ sd = scheme_data[scheme]
+
+ # Ground truth
+ # In stride visualization mode, ground-truth plots can be misleading
+ # (only sparse points are available), so we skip GT.
+ if stride == 1 and sd["target_key"] in processed:
+ gt_target = processed[sd["target_key"]][0, target_idx].cpu().item()
+ sd["viz_gt_stages"][local_idx] = int(gt_target)
+ sd["viz_gt_progress"][local_idx] = normalize_stage_tau(
+ gt_target,
+ num_stages=sd["num_stages"],
+ temporal_proportions=sd["temporal_props"],
+ subtask_names=sd["subtask_names"],
+ )
+
+ # Predictions
+ reward, stage_probs = reward_model.calculate_rewards(
+ text_embeddings=text_features,
+ video_embeddings=video_features,
+ state_features=state_features,
+ lengths=lengths,
+ return_all_frames=True,
+ return_stages=True,
+ head_mode=scheme,
+ )
+
+ # Handle both tensor and numpy outputs
+ if isinstance(reward, torch.Tensor):
+ reward = reward.cpu().numpy()
+ stage_probs = stage_probs.cpu().numpy()
+
+ if reward.ndim == 2:
+ sd["viz_progress"][local_idx] = reward[0, target_idx]
+ sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :]
+ else:
+ sd["viz_progress"][local_idx] = reward[target_idx]
+ sd["viz_stages"][local_idx] = stage_probs[target_idx, :]
+
+ # Clear GPU memory after each frame
+ del processed, video_features, text_features
+ if state_features is not None:
+ del state_features
+
+ torch.cuda.empty_cache()
+
+ # Interpolate predictions back to per-frame arrays for smooth visualization.
+ if stride > 1:
+ all_local = np.arange(num_frames)
+ for scheme in schemes_to_viz:
+ sd = scheme_data[scheme]
+
+ valid = np.isfinite(sd["viz_progress"])
+ valid_idx = np.where(valid)[0]
+ if valid_idx.size >= 1:
+ sd["viz_progress"] = interpolate_progress(
+ valid_idx, sd["viz_progress"][valid_idx], all_local
+ )
+
+ stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32)
+ for s in range(sd["num_stages"]):
+ stage_interp[:, s] = interpolate_progress(
+ valid_idx, sd["viz_stages"][valid_idx, s], all_local
+ )
+
+ stage_interp = np.clip(stage_interp, 0.0, 1.0)
+ row_sums = stage_interp.sum(axis=1, keepdims=True)
+ nz = row_sums.squeeze(-1) > 0
+ stage_interp[nz] = stage_interp[nz] / row_sums[nz]
+ sd["viz_stages"] = stage_interp
+ else:
+ # No valid points: keep NaNs/zeros; visualization will be empty.
+ sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0)
+
+ # Generate visualization for each head
+ ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)]
+ for scheme in schemes_to_viz:
+ sd = scheme_data[scheme]
+ stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])]
+ viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png"
+
+ visualize_episode(
+ frames=np.array(ordered_viz_frames),
+ progress_preds=sd["viz_progress"],
+ stage_preds=sd["viz_stages"],
+ title=f"{task} (Episode {episode_idx})",
+ output_path=viz_path,
+ stage_labels=stage_labels,
+ gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None,
+ gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None,
+ )
+
+ # Clear memory between episodes
+ torch.cuda.empty_cache()
+
+ logging.info(f"Visualizations saved to: {output_dir.absolute()}")
+
+
+def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
+ """Generate all frame indices, ordered by offset for cache-friendly access.
+
+ Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...]
+    This groups together frames that share similar temporal windows.
+ """
+ num_frames = ep_end - ep_start
+ indices = []
+ for offset in range(frame_gap):
+ for frame_rel in range(offset, num_frames, frame_gap):
+ indices.append(ep_start + frame_rel)
+ return indices
+
+
+def interpolate_progress(
+ computed_indices: np.ndarray,
+ computed_values: np.ndarray,
+ all_indices: np.ndarray,
+) -> np.ndarray:
+ """Linearly interpolate values to fill in gaps (robust to NaNs / edge cases)."""
+ computed_indices = np.asarray(computed_indices)
+ computed_values = np.asarray(computed_values)
+ all_indices = np.asarray(all_indices)
+
+ mask = np.isfinite(computed_values)
+ if mask.sum() == 0:
+ return np.full(all_indices.shape, np.nan, dtype=np.float32)
+ if mask.sum() == 1:
+ return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32)
+
+ out = np.interp(all_indices, computed_indices[mask], computed_values[mask])
+ return out.astype(np.float32)
+
+
+def compute_sarm_progress(
+ dataset_repo_id: str,
+ reward_model_path: str,
+ output_path: str | None = None,
+ head_mode: str = "sparse",
+ device: str = "cuda",
+ num_visualizations: int = 5,
+ output_dir: str = "./sarm_viz",
+ stride: int = 1,
+):
+ """
+ Compute SARM progress predictions for all frames in a dataset.
+
+ Args:
+ dataset_repo_id: HuggingFace dataset repo ID or local path
+ reward_model_path: Path to pretrained SARM model
+ output_path: Path to save results. If None, saves to dataset's cache directory
+ head_mode: SARM head to use ("sparse", "dense", or "both")
+ device: Device to use for inference
+ num_visualizations: Number of episodes to visualize (0 to skip)
+ output_dir: Directory to save visualizations
+ stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame)
+ """
+ dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device)
+
+ # Set preprocessor to eval mode to disable augmentations
+ if hasattr(preprocess, "eval"):
+ preprocess.eval()
+ for step in preprocess.steps:
+ if hasattr(step, "eval"):
+ step.eval()
+
+ image_key = reward_model.config.image_key
+ state_key = reward_model.config.state_key
+ frame_gap = reward_model.config.frame_gap
+ num_episodes = dataset.num_episodes
+ total_frames = dataset.num_frames
+ logging.info(f"Processing {total_frames} frames across {num_episodes} episodes")
+
+ # Determine which heads to compute
+ dual_mode = reward_model.config.uses_dual_heads
+ compute_sparse = head_mode in ("sparse", "both") or not dual_mode
+ compute_dense = head_mode in ("dense", "both") and dual_mode
+
+ # Storage arrays
+ all_indices = []
+ all_episode_indices = []
+ all_frame_indices = []
+ all_progress_sparse = [] if compute_sparse else None
+ all_progress_dense = [] if compute_dense else None
+
+ if stride > 1:
+ logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest")
+
+ # Process all episodes
+ for episode_idx in tqdm(range(num_episodes), desc="Episodes"):
+ ep = dataset.meta.episodes[episode_idx]
+ ep_start = ep["dataset_from_index"]
+ ep_end = ep["dataset_to_index"]
+
+ # Get task description
+ task = dataset[ep_start].get("task", "perform the task")
+
+ # Generate frames to compute (with stride applied)
+ all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
+ if stride > 1:
+ # Only compute every stride-th frame (relative to episode start)
+ compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0]
+ # Always include last frame for better interpolation at episode end
+ last_frame = ep_end - 1
+ if last_frame not in compute_indices:
+ compute_indices.append(last_frame)
+ compute_indices = sorted(set(compute_indices))
+ else:
+ compute_indices = all_ep_indices
+
+ center_idx = reward_model.config.n_obs_steps // 2 # Center of bidirectional window
+
+ # Dictionary to collect results
+ frame_results = {}
+
+ for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False):
+ try:
+ sample = dataset[query_idx]
+
+ batch = {
+ image_key: sample[image_key],
+ "task": task,
+ "index": query_idx,
+ "episode_index": episode_idx,
+ }
+ if state_key in sample:
+ batch[state_key] = sample[state_key]
+
+ with torch.no_grad():
+ processed = preprocess(batch)
+ video_features = processed["video_features"].to(device)
+ text_features = processed["text_features"].to(device)
+ state_features = processed.get("state_features")
+ if state_features is not None:
+ state_features = state_features.to(device)
+ lengths = processed.get("lengths")
+
+ sparse_val = np.nan
+ dense_val = np.nan
+
+ # Compute sparse prediction for center frame
+ if compute_sparse:
+ sparse_progress = reward_model.calculate_rewards(
+ text_embeddings=text_features,
+ video_embeddings=video_features,
+ state_features=state_features,
+ lengths=lengths,
+ return_all_frames=True,
+ head_mode="sparse",
+ )
+ sparse_val = float(
+ sparse_progress[0, center_idx]
+ if sparse_progress.ndim == 2
+ else sparse_progress[center_idx]
+ )
+
+ # Compute dense prediction for center frame
+ if compute_dense:
+ dense_progress = reward_model.calculate_rewards(
+ text_embeddings=text_features,
+ video_embeddings=video_features,
+ state_features=state_features,
+ lengths=lengths,
+ return_all_frames=True,
+ head_mode="dense",
+ )
+ dense_val = float(
+ dense_progress[0, center_idx]
+ if dense_progress.ndim == 2
+ else dense_progress[center_idx]
+ )
+
+ frame_results[query_idx] = (sparse_val, dense_val)
+
+ except Exception as e:
+ logging.warning(f"Failed to process frame {query_idx}: {e}")
+
+ # Interpolate to get values for all frames
+ computed_indices = np.array(sorted(frame_results.keys()))
+ computed_sparse = (
+ np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None
+ )
+ computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None
+
+ # All frame indices for this episode
+ all_frame_idx_array = np.arange(ep_start, ep_end)
+
+ use_interp = stride > 1 and len(computed_indices) > 1
+ if use_interp:
+ # Interpolate progress values for frames skipped by the stride
+ if compute_sparse:
+ interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array)
+ if compute_dense:
+ interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array)
+
+ # Store results for all frames
+ for i, frame_idx in enumerate(all_frame_idx_array):
+ local_idx = frame_idx - ep_start
+ all_indices.append(frame_idx)
+ all_episode_indices.append(episode_idx)
+ all_frame_indices.append(local_idx)
+ if compute_sparse:
+ if use_interp:
+ all_progress_sparse.append(float(interp_sparse[i]))
+ elif frame_idx in frame_results:
+ all_progress_sparse.append(frame_results[frame_idx][0])
+ else:
+ all_progress_sparse.append(np.nan)
+ if compute_dense:
+ if use_interp:
+ all_progress_dense.append(float(interp_dense[i]))
+ elif frame_idx in frame_results:
+ all_progress_dense.append(frame_results[frame_idx][1])
+ else:
+ all_progress_dense.append(np.nan)
+
+ # Create output table
+ table_data = {
+ "index": np.array(all_indices, dtype=np.int64),
+ "episode_index": np.array(all_episode_indices, dtype=np.int64),
+ "frame_index": np.array(all_frame_indices, dtype=np.int64),
+ }
+ if compute_sparse:
+ table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32)
+ if compute_dense:
+ table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32)
+
+ # Sort by index
+ df = pa.table(table_data).to_pandas()
+ df = df.sort_values("index").reset_index(drop=True)
+ final_table = pa.Table.from_pandas(df, preserve_index=False)
+
+ # Add metadata with reward model path
+ metadata = {b"reward_model_path": reward_model_path.encode()}
+ final_table = final_table.replace_schema_metadata(metadata)
+
+ # Determine output path
+ output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path)
+
+ # Save
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ pq.write_table(final_table, output_path)
+ logging.info(f"Saved {len(final_table)} frame progress values to {output_path}")
+
+ # Print statistics
+ if "progress_sparse" in df.columns:
+ valid = df["progress_sparse"].dropna()
+ logging.info(
+ f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
+ f"min={valid.min():.4f}, max={valid.max():.4f}"
+ )
+
+ if "progress_dense" in df.columns:
+ valid = df["progress_dense"].dropna()
+ logging.info(
+ f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
+ f"min={valid.min():.4f}, max={valid.max():.4f}"
+ )
+
+ # Visualize episodes after processing
+ if num_visualizations > 0:
+ viz_episodes = list(range(min(num_visualizations, num_episodes)))
+ logging.info(f"Generating {len(viz_episodes)} visualizations...")
+ visualize_sarm_predictions(
+ dataset=dataset,
+ reward_model=reward_model,
+ preprocess=preprocess,
+ episode_indices=viz_episodes,
+ head_mode=head_mode,
+ output_dir=Path(output_dir),
+ stride=stride,
+ )
+
+ return output_path
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Full RA-BC computation with visualizations
+ python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
+ --reward-model-path pepijn223/sarm_single_uni4
+
+ # Visualize predictions only (no RA-BC computation)
+ python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+ --dataset-repo-id lerobot/aloha_sim_insertion_human \\
+ --reward-model-path pepijn223/sarm_single_uni4 \\
+ --visualize-only \\
+ --num-visualizations 10
+ """,
+ )
+ parser.add_argument(
+ "--dataset-repo-id",
+ type=str,
+ required=True,
+ help="HuggingFace dataset repo ID or local path",
+ )
+ parser.add_argument(
+ "--reward-model-path",
+ type=str,
+ default=None,
+ help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)",
+ )
+ parser.add_argument(
+ "--output-path",
+ type=str,
+ default=None,
+ help="Output path for parquet. If not set, saves to sarm_progress.parquet in the dataset root",
+ )
+ parser.add_argument(
+ "--head-mode",
+ type=str,
+ default="sparse",
+ choices=["sparse", "dense", "both"],
+ help="SARM head to use (default: sparse)",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="Device to use (default: cuda)",
+ )
+ # Visualization options
+ parser.add_argument(
+ "--visualize-only",
+ action="store_true",
+ help="Only visualize SARM predictions (no RA-BC computation)",
+ )
+ parser.add_argument(
+ "--num-visualizations",
+ type=int,
+ default=5,
+ help="Number of episodes to visualize (default: 5, set to 0 to skip)",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="./sarm_viz",
+ help="Output directory for visualizations (default: ./sarm_viz)",
+ )
+ parser.add_argument(
+ "--push-to-hub",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Upload progress file to the dataset repo on the HuggingFace Hub (disable with --no-push-to-hub)",
+ )
+ parser.add_argument(
+ "--stride",
+ type=int,
+ default=1,
+ help="Compute progress every N frames, interpolate the rest (default: 1 = every frame)",
+ )
+
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+
+ # Try to get reward_model_path from parquet metadata if not provided
+ reward_model_path = args.reward_model_path
+ if reward_model_path is None:
+ # Load dataset to find parquet path
+ temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
+ parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet"
+ reward_model_path = get_reward_model_path_from_parquet(parquet_path)
+ if reward_model_path:
+ logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
+ else:
+ raise ValueError(
+ "--reward-model-path is required (no existing parquet with model metadata found)"
+ )
+
+ # Handle visualize-only mode
+ if args.visualize_only:
+ dataset, reward_model, preprocess = load_sarm_resources(
+ args.dataset_repo_id, reward_model_path, args.device
+ )
+ logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes")
+ viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes)))
+ visualize_sarm_predictions(
+ dataset=dataset,
+ reward_model=reward_model,
+ preprocess=preprocess,
+ episode_indices=viz_episodes,
+ head_mode=args.head_mode,
+ output_dir=Path(args.output_dir),
+ stride=args.stride,
+ )
+ print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}")
+ return
+
+ # Full RA-BC computation (compute_sarm_progress loads model/dataset itself)
+ output_path = compute_sarm_progress(
+ dataset_repo_id=args.dataset_repo_id,
+ reward_model_path=reward_model_path,
+ output_path=args.output_path,
+ head_mode=args.head_mode,
+ device=args.device,
+ num_visualizations=args.num_visualizations,
+ output_dir=args.output_dir,
+ stride=args.stride,
+ )
+
+ print(f"\nSARM progress values saved to: {output_path}")
+
+ # Upload to Hub if requested
+ if args.push_to_hub:
+ from huggingface_hub import HfApi
+
+ api = HfApi()
+ hub_path = "sarm_progress.parquet"
+
+ print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
+ api.upload_file(
+ path_or_fileobj=str(output_path),
+ path_in_repo=hub_path,
+ repo_id=args.dataset_repo_id,
+ repo_type="dataset",
+ )
+ print(
+ f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
+ )
+
+ print("\nTo use in training, add to your config:")
+ print(" use_rabc: true")
+ print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
+ print(" rabc_head_mode: sparse # or dense")
+ else:
+ print("\nTo use in training, add to your config:")
+ print(" use_rabc: true")
+ print(f" rabc_progress_path: {output_path}")
+ print(" rabc_head_mode: sparse # or dense")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py
new file mode 100644
index 000000000..59cb352d5
--- /dev/null
+++ b/src/lerobot/policies/sarm/configuration_sarm.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
+Paper: https://arxiv.org/abs/2509.25358
+"""
+
+from dataclasses import dataclass, field
+
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.optim.optimizers import AdamWConfig
+from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+
+
+@PreTrainedConfig.register_subclass("sarm")
+@dataclass
+class SARMConfig(PreTrainedConfig):
+ """Configuration class for SARM (Stage-Aware Reward Modeling).
+
+ Supports three annotation modes:
+
+ 1. single_stage (default): No annotations needed. Uses the episode's task description
+ as a single stage covering the entire episode.
+
+ 2. dense_only: Uses dense (fine-grained) annotations from VLM, with an auto-generated
+ single sparse "task" stage covering the full episode. The dense head learns detailed
+ subtask progression while sparse provides overall task completion.
+
+ 3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained)
+ annotations from VLM. Both heads are trained on their respective annotations.
+
+ The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions
+ are loaded/generated during model initialization.
+ """
+
+ annotation_mode: str = "single_stage" # "single_stage", "dense_only", or "dual"
+ n_obs_steps: int = 8 # Number of observation history steps
+ frame_gap: int = 30 # Gap between sampled frames (30 frames at 30 fps = 1 second)
+ max_rewind_steps: int = 4 # Maximum rewind steps for temporal augmentation
+
+ # Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property)
+ # During training with rewind: [obs_frames] + [rewind_frames]
+ # During inference: [obs_frames] only
+
+ # Architecture params
+ image_dim: int = 512
+ text_dim: int = 512
+ hidden_dim: int = 768
+ num_heads: int = 12
+ num_layers: int = 8
+ max_state_dim: int = 32
+ drop_n_last_frames: int = 1
+ batch_size: int = 64
+ clip_batch_size: int = 64
+ dropout: float = 0.1
+ stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
+
+ rewind_probability: float = 0.8
+ language_perturbation_probability: float = 0.2
+
+ # Sparse annotations (high-level stages)
+ num_sparse_stages: int = 1
+ sparse_subtask_names: list | None = None
+ sparse_temporal_proportions: list | None = None
+
+ # Dense annotations (fine-grained stages)
+ num_dense_stages: int | None = None
+ dense_subtask_names: list | None = None
+ dense_temporal_proportions: list | None = None
+
+ pretrained_model_path: str | None = None
+ device: str | None = None
+ image_key: str = "observation.images.top" # Dataset key of the camera image to use
+ state_key: str = "observation.state"
+
+ # Populated by the processor (video_features, state_features, text_features)
+ input_features: dict = field(default_factory=dict)
+
+ # Output features (updated in __post_init__)
+ output_features: dict = field(
+ default_factory=lambda: {
+ "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
+ "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
+ }
+ )
+
+ normalization_mapping: dict[str, NormalizationMode] = field(
+ default_factory=lambda: {
+ "VISUAL": NormalizationMode.IDENTITY,
+ "STATE": NormalizationMode.MEAN_STD,
+ "LANGUAGE": NormalizationMode.IDENTITY,
+ "REWARD": NormalizationMode.IDENTITY,
+ }
+ )
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
+ raise ValueError(
+ f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
+ )
+
+ if self.annotation_mode == "single_stage":
+ # Use task description as stage name, full episode as one stage
+ self.num_sparse_stages = 1
+ self.sparse_subtask_names = ["task"]
+ self.sparse_temporal_proportions = [1.0]
+ self.num_dense_stages = None
+ self.dense_subtask_names = None
+ self.dense_temporal_proportions = None
+
+ elif self.annotation_mode == "dense_only":
+ self.num_sparse_stages = 1
+ self.sparse_subtask_names = ["task"]
+ self.sparse_temporal_proportions = [1.0]
+
+ self.input_features = {}
+ self.output_features = {}
+
+ if self.image_key:
+ self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL)
+
+ self.input_features[self.state_key] = PolicyFeature(
+ shape=(self.max_state_dim,),
+ type=FeatureType.STATE,
+ )
+
+ # Update output features based on annotation_mode
+ self.output_features["sparse_stage"] = PolicyFeature(
+ shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
+ )
+ self.output_features["sparse_progress"] = PolicyFeature(
+ shape=(self.num_frames, 1), type=FeatureType.REWARD
+ )
+ if self.annotation_mode in ["dense_only", "dual"]:
+ dense_stages = self.num_dense_stages or self.num_sparse_stages
+ self.output_features["dense_stage"] = PolicyFeature(
+ shape=(self.num_frames, dense_stages), type=FeatureType.REWARD
+ )
+ self.output_features["dense_progress"] = PolicyFeature(
+ shape=(self.num_frames, 1), type=FeatureType.REWARD
+ )
+
+ if self.max_rewind_steps >= self.n_obs_steps:
+ raise ValueError(
+ f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})"
+ )
+ if self.num_sparse_stages < 1:
+ raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}")
+ if (
+ self.annotation_mode in ["dense_only", "dual"]
+ and self.num_dense_stages is not None
+ and self.num_dense_stages < 2
+ ):
+ raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}")
+
+ def get_optimizer_preset(self) -> AdamWConfig:
+ """Get default optimizer configuration for SARM training."""
+ return AdamWConfig(
+ lr=5e-5,
+ weight_decay=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ )
+
+ def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
+ """Get default learning rate scheduler configuration."""
+ return CosineDecayWithWarmupSchedulerConfig(
+ peak_lr=5e-5,
+ decay_lr=5e-6,
+ num_warmup_steps=500,
+ num_decay_steps=50000,
+ )
+
+ def validate_features(self) -> None:
+ pass
+
+ @property
+ def uses_dual_heads(self) -> bool:
+ """Whether the model uses dual heads (dense_only or dual annotation modes)."""
+ return self.annotation_mode in ["dense_only", "dual"]
+
+ @property
+ def num_frames(self) -> int:
+ """Total number of frames in sequence.
+
+ For training: 1 + n_obs_steps + max_rewind_steps
+ The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)]
+ """
+ return 1 + self.n_obs_steps + self.max_rewind_steps
+
+ @property
+ def max_length(self) -> int:
+ return self.num_frames
+
+ @property
+ def observation_delta_indices(self) -> list[int]:
+ """Bidirectional frame sampling centered on the target frame.
+
+ Example with n_obs_steps=8, gap=30:
+ Before: [-120, -90, -60, -30] (4 frames)
+ Current: [0] (1 frame)
+ After: [30, 60, 90, 120] (4 frames)
+ Observation frames: 9, plus max_rewind_steps rewind placeholder deltas,
+ for num_frames indices in total.
+ """
+ half_steps = self.n_obs_steps // 2
+
+ past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
+ future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
+ obs_deltas = past_deltas + [0] + future_deltas
+
+ # Rewind placeholders
+ rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
+
+ return obs_deltas + rewind_deltas
+
+ @property
+ def action_delta_indices(self) -> None:
+ """SARM is a reward model, not an action policy."""
+ return None
+
+ @property
+ def reward_delta_indices(self) -> None:
+ return None
diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py
new file mode 100644
index 000000000..a88b2ad64
--- /dev/null
+++ b/src/lerobot/policies/sarm/modeling_sarm.py
@@ -0,0 +1,793 @@
+#!/usr/bin/env python
+
+# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
+# and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
+
+Paper: https://arxiv.org/abs/2509.25358
+
+- StageTransformer: Predicts stage classification (sparse/dense)
+- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage
+"""
+
+import json
+import logging
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F # noqa: N812
+from torch import Tensor
+
+from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
+from lerobot.policies.sarm.sarm_utils import (
+ normalize_stage_tau,
+ pad_state_to_max_dim,
+)
+
+
+class StageTransformer(nn.Module):
+ """
+ Stage classification transformer for SARM.
+
+ Predicts which stage/subtask the current frame belongs to.
+ Supports both sparse (high-level) and dense (fine-grained) annotation schemes.
+
+ Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D)
+ Output: stage logits (B, T, num_classes)
+ """
+
+ def __init__(
+ self,
+ d_model: int = 512,
+ vis_emb_dim: int = 512,
+ text_emb_dim: int = 512,
+ state_dim: int = 32,
+ n_layers: int = 6,
+ n_heads: int = 8,
+ dropout: float = 0.1,
+ num_cameras: int = 1,
+ num_classes_sparse: int = 4,
+ num_classes_dense: int = 8,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.num_cameras = num_cameras
+
+ # Projections
+ self.lang_proj = nn.Linear(text_emb_dim, d_model)
+ self.visual_proj = nn.Linear(vis_emb_dim, d_model)
+ self.state_proj = nn.Linear(state_dim, d_model)
+
+ # Encoder
+ enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
+ self.transformer = nn.TransformerEncoder(enc_layer, n_layers)
+
+ # Positional bias on first visual frame
+ self.first_pos = nn.Parameter(torch.zeros(1, d_model))
+
+ # Shared fusion MLP
+ # Fuses (num_cameras + 2) streams: cameras + lang + state
+ fused_in = d_model * (num_cameras + 2)
+ self.fusion_backbone = nn.Sequential(
+ nn.LayerNorm(fused_in),
+ nn.Linear(fused_in, d_model),
+ nn.ReLU(),
+ )
+
+ # Scheme-specific heads
+ self.heads = nn.ModuleDict(
+ {
+ "sparse": nn.Linear(d_model, num_classes_sparse),
+ "dense": nn.Linear(d_model, num_classes_dense),
+ }
+ )
+
+ def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
+ """
+ Prepare language embeddings for fusion.
+
+ Accepts lang_emb of shape:
+ - (B, text_emb_dim) -> broadcast across time
+ - (B, T, text_emb_dim) -> per-timestep (dense annotation mode)
+
+ Returns: (B, 1, T, D)
+ """
+ if lang_emb.dim() == 3:
+ # (B, T, E) -> (B, T, D) -> (B, 1, T, D)
+ lang_proj = self.lang_proj(lang_emb).unsqueeze(1)
+ else:
+ # (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D)
+ lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
+ return lang_proj
+
+ def forward(
+ self,
+ img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
+ lang_emb: torch.Tensor, # (B, E) or (B, T, E)
+ state: torch.Tensor, # (B, T, state_dim)
+ lengths: torch.Tensor, # (B,) - valid sequence lengths
+ scheme: str = "sparse", # "sparse" or "dense"
+ ) -> torch.Tensor:
+ """
+ Forward pass for stage classification.
+
+ Args:
+ img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras
+ lang_emb: Language embeddings (B, E) or (B, T, E) for dense
+ state: State features (B, T, state_dim)
+ lengths: Valid sequence lengths (B,) for masking
+ scheme: "sparse" or "dense" for head selection
+
+ Returns:
+ Stage logits (B, T, num_classes)
+ """
+ assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
+
+ B, N, T, _ = img_seq.shape # noqa: N806
+ D = self.d_model # noqa: N806
+ device = img_seq.device
+
+ # Project inputs
+ vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
+ state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
+ lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
+
+ # Concatenate streams
+ # cameras + lang + state -> (B, N+2, T, D)
+ x = torch.cat([vis_proj, lang_proj, state_proj], dim=1)
+
+ # Add positional bias to first visual frame
+ x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
+
+ # Flatten to tokens for Transformer
+ x_tokens = x.view(B, (N + 2) * T, D)
+ L = x_tokens.size(1) # noqa: N806
+
+ # Create padding mask
+ base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) # (B, T)
+ mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T)
+
+ # Create causal mask
+ causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
+
+ # Encode
+ h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
+
+ # Reshape and fuse
+ h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D)
+ fused = self.fusion_backbone(h) # (B, T, D)
+
+ # Scheme-specific logits
+ logits = self.heads[scheme](fused) # (B, T, num_classes)
+ return logits
+
+
+class SubtaskTransformer(nn.Module):
+ """
+ Subtask progress regression transformer for SARM.
+
+ Predicts within-stage normalized progress (tau) conditioned on stage prior.
+ The stage prior is a one-hot encoding passed from StageTransformer predictions.
+
+ Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D)
+ Output: tau predictions (B, T) in [0, 1]
+ """
+
+ def __init__(
+ self,
+ d_model: int = 512,
+ vis_emb_dim: int = 512,
+ text_emb_dim: int = 512,
+ state_dim: int = 32,
+ n_layers: int = 6,
+ n_heads: int = 8,
+ dropout: float = 0.1,
+ num_cameras: int = 1,
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.num_cameras = num_cameras
+
+ # Projections
+ self.lang_proj = nn.Linear(text_emb_dim, d_model)
+ self.visual_proj = nn.Linear(vis_emb_dim, d_model)
+ self.state_proj = nn.Linear(state_dim, d_model)
+
+ # Encoder
+ enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
+ self.transformer = nn.TransformerEncoder(enc, n_layers)
+
+ # Learned bias on first visual frame
+ self.first_pos = nn.Parameter(torch.zeros(1, d_model))
+
+ # Shared fusion backbone
+ # Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb
+ fused_in = d_model * (num_cameras + 3)
+ self.fusion_backbone = nn.Sequential(
+ nn.LayerNorm(fused_in),
+ nn.Linear(fused_in, d_model),
+ nn.ReLU(),
+ )
+
+ # Scheme-specific regression heads
+ self.heads = nn.ModuleDict(
+ {
+ "sparse": nn.Linear(d_model, 1),
+ "dense": nn.Linear(d_model, 1),
+ }
+ )
+
+ def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
+ """
+ Prepare language embeddings for fusion.
+ """
+ if lang_emb.dim() == 3:
+ # (B, T, E) -> (B, T, D) -> (B, 1, T, D)
+ return self.lang_proj(lang_emb).unsqueeze(1)
+ else:
+ # (B, E) -> (B, 1, 1, D) -> (B, 1, T, D)
+ return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
+
+ def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor:
+ """
+ Deterministic projection of one-hot stage to d_model by pad/truncate.
+
+ Args:
+ stage_prior: One-hot stage embedding (B, 1, T, C)
+
+ Returns:
+ Projected stage embedding (B, 1, T, d_model)
+ """
+ B, one, T, C = stage_prior.shape # noqa: N806
+ D = self.d_model # noqa: N806
+ if D == C:
+ return stage_prior
+ elif D > C:
+ pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype)
+ return torch.cat([stage_prior, pad], dim=-1)
+ else:
+ return stage_prior[..., :D]
+
+ def forward(
+ self,
+ img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
+ lang_emb: torch.Tensor, # (B, E) or (B, T, E)
+ state: torch.Tensor, # (B, T, state_dim)
+ lengths: torch.Tensor, # (B,) - valid sequence lengths
+ stage_prior: torch.Tensor, # (B, 1, T, C) one-hot from gen_stage_emb
+ scheme: str = "sparse", # "sparse" or "dense"
+ ) -> torch.Tensor:
+ """
+ Forward pass for subtask progress regression.
+
+ Args:
+ img_seq: Image embeddings (B, N, T, vis_emb_dim)
+ lang_emb: Language embeddings (B, E) or (B, T, E)
+ state: State features (B, T, state_dim)
+ lengths: Valid sequence lengths (B,) for masking
+ stage_prior: One-hot stage prior (B, 1, T, num_classes)
+ scheme: "sparse" or "dense" for head selection
+
+ Returns:
+ Tau predictions (B, T) in [0, 1] via sigmoid
+ """
+ assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
+
+ B, N, T, _ = img_seq.shape # noqa: N806
+ D = self.d_model # noqa: N806
+ device = img_seq.device
+
+ # Project inputs
+ vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
+ state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
+ lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
+ stage_emb = self._stage_to_dmodel(stage_prior) # (B, 1, T, D)
+
+ # Concatenate all streams
+ # cameras + lang + state + stage_emb -> (B, N+3, T, D)
+ x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1)
+
+ # Add positional bias to first visual frame
+ x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
+
+ # Flatten to tokens
+ x_tokens = x.view(B, (N + 3) * T, D)
+ L = x_tokens.size(1) # noqa: N806
+
+ # Create padding mask
+ base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1)
+ mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T)
+
+ # Create causal mask
+ causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
+
+ # Encode
+ h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
+
+ # Reshape and fuse
+ h = h.view(B, N + 3, T, D)
+ h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D)
+ fused = self.fusion_backbone(h_flat) # (B, T, D)
+
+ # Scheme-specific regression head -> sigmoid
+ r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1) # (B, T)
+ return r
+
+
+def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
+ """
+ Generate one-hot stage embeddings from targets.
+
+ Args:
+ num_classes: Number of stage classes
+ targets: Target values (B, T) where integer part is stage index
+
+ Returns:
+ One-hot stage embedding (B, 1, T, num_classes)
+ """
+ # Integer part of float targets -> [0, C-1]
+ idx = targets.long().clamp(min=0, max=num_classes - 1) # (B, T)
+ # One-hot encode stage indices
+ stage_onehot = F.one_hot(idx, num_classes).float() # (B, T, C)
+ stage_onehot = stage_onehot.unsqueeze(1) # (B, 1, T, C)
+ return stage_onehot
+
+
+class SARMRewardModel(PreTrainedPolicy):
+ """
+ SARM Reward Model for stage-aware task completion rewards.
+
+ Uses two separate transformer models:
+ - StageTransformer: Classifies which stage/subtask
+ - SubtaskTransformer: Predicts within-stage progress (tau)
+
+ Training uses ground-truth stage conditioning 75% of the time and predicted
+ stages 25% of the time (teacher forcing).
+ """
+
+ name = "sarm"
+ config_class = SARMConfig
+
+ def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
+ super().__init__(config, dataset_stats)
+ config.validate_features()
+ self.config = config
+ self.dataset_stats = dataset_stats
+ self.device = torch.device(
+ config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu"
+ )
+
+ # Load temporal proportions based on annotation_mode
+ if config.annotation_mode == "single_stage":
+ logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}")
+ elif dataset_meta is not None:
+ self._load_temporal_proportions(dataset_meta)
+
+ # Create two separate models
+ self.stage_model = StageTransformer(
+ d_model=config.hidden_dim,
+ vis_emb_dim=config.image_dim,
+ text_emb_dim=config.text_dim,
+ state_dim=config.max_state_dim,
+ n_layers=config.num_layers,
+ n_heads=config.num_heads,
+ dropout=config.dropout,
+ num_cameras=1, # Single camera for now
+ num_classes_sparse=config.num_sparse_stages,
+ num_classes_dense=config.num_dense_stages or config.num_sparse_stages,
+ )
+
+ self.subtask_model = SubtaskTransformer(
+ d_model=config.hidden_dim,
+ vis_emb_dim=config.image_dim,
+ text_emb_dim=config.text_dim,
+ state_dim=config.max_state_dim,
+ n_layers=config.num_layers,
+ n_heads=config.num_heads,
+ dropout=config.dropout,
+ num_cameras=1,
+ )
+
+ self.stage_model.to(self.device)
+ self.subtask_model.to(self.device)
+
+ # GT/predicted stage ratio for teacher forcing
+ self.gt_stage_ratio = 0.75
+
+ if config.uses_dual_heads:
+ logging.info(
+ f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, "
+ f"{config.num_dense_stages} dense stages"
+ )
+ else:
+ logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages")
+
+ logging.info(f"SARM initialized on {self.device}")
+
+ def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]:
+ """Load temporal proportions from a JSON file (preserving order)."""
+ if not path.exists():
+ raise ValueError(
+ f"{annotation_type.capitalize()} temporal proportions not found at {path}. "
+ f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations."
+ )
+ with open(path) as f:
+ proportions_dict = json.load(f)
+ names = list(proportions_dict.keys())
+ logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}")
+ logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}")
+ return names, [proportions_dict[name] for name in names]
+
+ def _load_temporal_proportions(self, dataset_meta) -> None:
+ """Load temporal proportions based on annotation_mode."""
+ meta_path = dataset_meta.root / "meta"
+
+ if self.config.annotation_mode == "dual":
+ names, props = self._load_proportions_from_json(
+ meta_path / "temporal_proportions_sparse.json", "sparse"
+ )
+ (
+ self.config.num_sparse_stages,
+ self.config.sparse_subtask_names,
+ self.config.sparse_temporal_proportions,
+ ) = len(names), names, props
+
+ if self.config.annotation_mode in ["dense_only", "dual"]:
+ names, props = self._load_proportions_from_json(
+ meta_path / "temporal_proportions_dense.json", "dense"
+ )
+ (
+ self.config.num_dense_stages,
+ self.config.dense_subtask_names,
+ self.config.dense_temporal_proportions,
+ ) = len(names), names, props
+ if self.config.annotation_mode == "dense_only":
+ logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}")
+
+ def to(self, device):
+ """Override to method to ensure all components move together."""
+ super().to(device)
+ self.device = device if isinstance(device, torch.device) else torch.device(device)
+ self.stage_model.to(device)
+ self.subtask_model.to(device)
+ return self
+
+ @torch.no_grad()
+ def calculate_rewards(
+ self,
+ text_embeddings: np.ndarray | torch.Tensor,
+ video_embeddings: np.ndarray | torch.Tensor,
+ state_features: np.ndarray | torch.Tensor | None = None,
+ lengths: np.ndarray | torch.Tensor | None = None,
+ return_all_frames: bool = False,
+ return_stages: bool = False,
+ return_confidence: bool = False,
+ head_mode: str | None = "sparse",
+ frame_index: int | None = None,
+ ) -> np.ndarray | tuple:
+ """
+ Calculate rewards for given text, video, and state representations.
+
+ This is the canonical method for SARM reward computation, used for:
+ - Inference/visualization
+ - RA-BC weight computation
+
+ Args:
+ text_embeddings: Encoded text representations (batch_size, 512)
+ video_embeddings: Encoded video representations (batch_size, num_frames, 512)
+ state_features: Joint state features (batch_size, num_frames, state_dim)
+ lengths: Valid sequence lengths (batch_size,)
+ return_all_frames: If True, return rewards for all frames
+ return_stages: If True, also return stage predictions
+ return_confidence: If True, also return stage confidence
+ head_mode: Which head to use ("sparse" or "dense")
+ frame_index: Index of the target frame to extract (default: n_obs_steps).
+
+ Returns:
+ Rewards and optionally stage probs/confidence.
+ """
+ if isinstance(text_embeddings, np.ndarray):
+ text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
+ if isinstance(video_embeddings, np.ndarray):
+ video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
+ if state_features is not None and isinstance(state_features, np.ndarray):
+ state_features = torch.tensor(state_features, dtype=torch.float32)
+
+ # Handle single sample case
+ if text_embeddings.dim() == 1:
+ text_embeddings = text_embeddings.unsqueeze(0)
+ video_embeddings = video_embeddings.unsqueeze(0)
+ if state_features is not None:
+ state_features = state_features.unsqueeze(0)
+ single_sample = True
+ else:
+ single_sample = False
+
+ batch_size = video_embeddings.shape[0]
+ seq_len = video_embeddings.shape[1]
+
+ scheme = head_mode
+
+ # Default lengths if not provided
+ if lengths is None:
+ lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)
+ elif isinstance(lengths, np.ndarray):
+ lengths = torch.tensor(lengths, dtype=torch.int32)
+
+ # Reshape video to (B, N, T, D) for multi-camera format
+ # Currently single camera: (B, T, D) -> (B, 1, T, D)
+ img_seq = video_embeddings.unsqueeze(1).to(self.device)
+ lang_emb = text_embeddings.to(self.device)
+ state = (
+ state_features.to(self.device)
+ if state_features is not None
+ else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
+ )
+ lens = lengths.to(self.device)
+
+ # Pad state to max_state_dim
+ state = pad_state_to_max_dim(state, self.config.max_state_dim)
+
+ # Get num_classes for this scheme
+ num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
+
+ # Run stage model
+ stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme)
+ stage_probs = F.softmax(stage_logits, dim=-1) # (B, T, num_classes)
+ stage_idx = stage_probs.argmax(dim=-1) # (B, T)
+ stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1) # (B, T)
+
+ # Create one-hot stage prior
+ stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
+ stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
+
+ # Run subtask model
+ tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme)
+
+ # Compute final reward: stage + tau
+ raw_reward = stage_idx.float() + tau_pred # (B, T)
+
+ # Normalize to [0, 1] using temporal proportions for proper weighting
+ if scheme == "sparse":
+ normalized_reward = normalize_stage_tau(
+ raw_reward,
+ num_stages=num_classes,
+ temporal_proportions=self.config.sparse_temporal_proportions,
+ subtask_names=self.config.sparse_subtask_names,
+ )
+ else:
+ normalized_reward = normalize_stage_tau(
+ raw_reward,
+ num_stages=num_classes,
+ temporal_proportions=self.config.dense_temporal_proportions,
+ subtask_names=self.config.dense_subtask_names,
+ )
+
+ # Default frame index is n_obs_steps (last observation frame)
+ if frame_index is None:
+ frame_index = self.config.n_obs_steps
+
+ # Prepare outputs
+ if return_all_frames:
+ rewards = normalized_reward.cpu().numpy()
+ else:
+ rewards = normalized_reward[:, frame_index].cpu().numpy()
+
+ if single_sample:
+ rewards = rewards[0]
+
+ outputs = [rewards]
+ if return_stages:
+ probs = stage_probs.cpu().numpy()
+ if single_sample:
+ probs = probs[0]
+ outputs.append(probs)
+ if return_confidence:
+ conf = stage_conf.cpu().numpy()
+ if single_sample:
+ conf = conf[0]
+ outputs.append(conf)
+
+ return outputs[0] if len(outputs) == 1 else tuple(outputs)
+
+ def train(self, mode: bool = True):
+ """Set training mode for both models."""
+ super().train(mode)
+ self.stage_model.train(mode)
+ self.subtask_model.train(mode)
+ return self
+
+ def eval(self):
+ """Set evaluation mode for both models."""
+ return self.train(False)
+
+ def parameters(self):
+ """Override to return trainable parameters from both models."""
+ from itertools import chain
+
+ return chain(self.stage_model.parameters(), self.subtask_model.parameters())
+
+ def get_optim_params(self):
+ """Override to return optimizer parameters from both models."""
+ return self.parameters()
+
+ def reset(self):
+ """Required by PreTrainedPolicy but not used for reward models."""
+ pass
+
+ def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+ """Required by PreTrainedPolicy but not used for reward models."""
+ raise NotImplementedError("SARM model does not predict action chunks")
+
+ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+ """Required by PreTrainedPolicy but not used for SARM."""
+ raise NotImplementedError("SARM model does not select actions")
+
+ def _train_step(
+ self,
+ img_emb: torch.Tensor, # (B, N, T, D)
+ lang_emb: torch.Tensor, # (B, E) or (B, T, E)
+ state: torch.Tensor, # (B, T, state_dim)
+ lengths: torch.Tensor, # (B,)
+ targets: torch.Tensor, # (B, T) - format: stage.tau
+ scheme: str,
+ ) -> dict[str, torch.Tensor]:
+ """
+ Single training step for one annotation scheme.
+
+ Implements 75%/25% ground-truth/predicted stage conditioning (controlled by `gt_stage_ratio`).
+
+ Args:
+ img_emb: Image embeddings (B, N, T, D)
+ lang_emb: Language embeddings
+ state: State features
+ lengths: Valid sequence lengths
+ targets: Target values where floor=stage, remainder=tau
+ scheme: "sparse" or "dense"
+
+ Returns:
+ Dict with stage_loss, subtask_loss, total_loss
+ """
+ num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages
+
+ # Ground truth: stage (integer) and tau (fractional)
+ # Clamp stage indices to valid range [0, num_classes-1] to handle edge cases
+ # where targets may exceed expected range (e.g., frames between subtasks)
+ gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1) # (B, T)
+ gt_tau = torch.remainder(targets, 1.0) # (B, T)
+
+ # Run stage model
+ stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme)
+
+ # 75%/25% GT/predicted stage conditioning
+ if random.random() < self.gt_stage_ratio:
+ # Mode 1: Use ground truth stage -> one-hot
+ stage_emb = gen_stage_emb(num_classes, targets) # (B, 1, T, C)
+ else:
+ # Mode 2: Use predicted stage argmax -> one-hot
+ stage_idx = stage_pred.argmax(dim=-1) # (B, T)
+ stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C)
+ stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C)
+
+ # Run subtask model with stage prior
+ tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme)
+
+ # Compute losses
+ stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean")
+ subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean")
+
+ return {
+ "stage_loss": stage_loss,
+ "subtask_loss": subtask_loss,
+ "total_loss": stage_loss + subtask_loss,
+ }
+
+ def forward(self, batch):
+ """
+ Forward pass for SARM reward model training.
+
+ Uses stage+tau target format where:
+ - Integer part = stage index
+ - Fractional part = within-stage progress (tau)
+
+ Training uses 75%/25% GT/predicted stage conditioning.
+
+ Args:
+ batch: Dictionary with 'observation' containing:
+ - 'video_features': (B, T, 512) pre-encoded video features
+ - 'text_features': (B, 512) or (B, T, 512) text features
+ - 'state_features': (B, T, state_dim) joint state features
+ - 'lengths': (B,) valid sequence lengths
+ - 'sparse_targets': (B, T) sparse targets (stage.tau format)
+ - 'dense_targets': (B, T) dense targets (optional, for dual mode)
+
+ Returns:
+ Tuple of (total_loss, output_dict with loss components)
+ """
+ observation = batch.get("observation", batch)
+
+ # Extract features
+ video_features = observation["video_features"].to(self.device)
+ text_features = observation["text_features"].to(self.device)
+ state_features = observation.get("state_features")
+ if state_features is not None:
+ state_features = state_features.to(self.device)
+
+ batch_size = video_features.shape[0]
+ seq_len = video_features.shape[1]
+
+ # Get lengths (default to full sequence)
+ lengths = observation.get("lengths")
+ if lengths is None:
+ lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device)
+ else:
+ lengths = lengths.to(self.device)
+
+ # Reshape video to (B, N, T, D) - single camera
+ img_emb = video_features.unsqueeze(1)
+
+ # Pad state to max_state_dim
+ if state_features is None:
+ state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
+ else:
+ state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim)
+
+ output_dict = {}
+ total_loss = torch.tensor(0.0, device=self.device)
+
+ # Sparse training (always)
+ sparse_targets = observation.get("sparse_targets")
+ if sparse_targets is None:
+ # Try legacy format
+ sparse_targets = observation.get("targets")
+ if sparse_targets is None:
+ raise ValueError("sparse_targets (or targets) is required for SARM training")
+ sparse_targets = sparse_targets.to(self.device)
+
+ sparse_result = self._train_step(
+ img_emb, text_features, state_features, lengths, sparse_targets, scheme="sparse"
+ )
+ output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item()
+ output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item()
+ total_loss = total_loss + sparse_result["total_loss"]
+
+ # Dense training (if dual mode)
+ if self.config.uses_dual_heads:
+ dense_targets = observation.get("dense_targets")
+ if dense_targets is not None:
+ dense_targets = dense_targets.to(self.device)
+ dense_result = self._train_step(
+ img_emb, text_features, state_features, lengths, dense_targets, scheme="dense"
+ )
+ output_dict["dense_stage_loss"] = dense_result["stage_loss"].item()
+ output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item()
+ total_loss = total_loss + dense_result["total_loss"]
+
+ output_dict["total_loss"] = total_loss.item()
+ return total_loss, output_dict
+
+
+def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
+ """Compute cross-entropy loss for stage classification."""
+ _, _, num_stages = stage_logits.shape
+ stage_logits_flat = stage_logits.reshape(-1, num_stages)
+ # Clamp target stage indices to valid range [0, num_stages-1]
+ target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1)
+ return F.cross_entropy(stage_logits_flat, target_stages_flat)
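
The stage.tau target encoding used throughout the losses above can be illustrated with a small self-contained sketch (plain PyTorch, no SARM dependencies; the target values are made up):

```python
import torch

# Hypothetical targets in SARM's stage.tau format:
# integer part = stage index, fractional part = within-stage progress (tau).
targets = torch.tensor([0.25, 1.80, 2.999])

gt_stage = torch.floor(targets).long()   # stage indices: 0, 1, 2
gt_tau = torch.remainder(targets, 1.0)   # within-stage progress

# Recombining stage and tau recovers the raw (unnormalized) reward,
# mirroring `raw_reward = stage_idx.float() + tau_pred` at inference time.
raw_reward = gt_stage.float() + gt_tau
assert torch.allclose(raw_reward, targets)
```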
diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py
new file mode 100644
index 000000000..5c617282a
--- /dev/null
+++ b/src/lerobot/policies/sarm/processor_sarm.py
@@ -0,0 +1,518 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SARM Processor for encoding images/text and generating stage+tau targets."""
+
+import random
+from typing import Any
+
+import numpy as np
+import pandas as pd
+import torch
+from faker import Faker
+from PIL import Image
+from transformers import CLIPModel, CLIPProcessor
+
+from lerobot.configs.types import FeatureType, PolicyFeature
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
+from lerobot.policies.sarm.sarm_utils import (
+ apply_rewind_augmentation,
+ compute_absolute_indices,
+ find_stage_and_tau,
+ pad_state_to_max_dim,
+)
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyAction,
+ PolicyProcessorPipeline,
+ ProcessorStep,
+ RenameObservationsProcessorStep,
+)
+from lerobot.processor.converters import (
+ from_tensor_to_numpy,
+ policy_action_to_transition,
+ transition_to_policy_action,
+)
+from lerobot.processor.core import EnvTransition, TransitionKey
+from lerobot.processor.pipeline import PipelineFeatureType
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+
+class SARMEncodingProcessorStep(ProcessorStep):
+ """ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM."""
+
+ def __init__(
+ self,
+ config: SARMConfig,
+ image_key: str | None = None,
+ dataset_meta=None,
+ dataset_stats: dict | None = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.image_key = image_key or config.image_key
+ self.dataset_meta = dataset_meta
+ self.dataset_stats = dataset_stats
+ self.annotation_mode = config.annotation_mode
+
+ # Helper to create temporal proportions dict
+ def make_props_dict(names, props):
+ return dict(zip(names, props, strict=True)) if names and props else None
+
+ # Sparse annotations (always needed)
+ self.sparse_temporal_proportions = make_props_dict(
+ config.sparse_subtask_names, config.sparse_temporal_proportions
+ )
+ self.sparse_subtask_names = config.sparse_subtask_names
+
+ # Dense annotations (only for dual mode)
+ self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None
+ self.dense_temporal_proportions = (
+ make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions)
+ if config.uses_dual_heads
+ else None
+ )
+
+ self.device = torch.device(
+ self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu"
+ )
+
+ self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
+ self.clip_model.to(self.device)
+ self.clip_model.eval()
+
+ self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"]
+ self.fake = Faker()
+
+ def _find_episode_for_frame(self, frame_idx: int) -> int:
+ """Find the episode index for a given frame index."""
+ for ep_idx in range(len(self.dataset_meta.episodes)):
+ ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
+ ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
+ if ep_start <= frame_idx < ep_end:
+ return ep_idx
+ return 0
+
+ def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
+ """Get episode indices for each frame index."""
+ if episode_index is None:
+ return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
+
+ episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
+
+ # If single episode but multiple frames, compute episode for each frame
+ if len(episode_indices) == 1 and len(frame_indices) > 1:
+ return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
+
+ return episode_indices
+
+ def _generate_perturbed_task(self) -> str:
+ """Generate a random perturbed task string for language perturbation."""
+ num_words = random.randint(1, 5)
+ verb = random.choice(self.verbs)
+ phrase = " ".join([verb] + self.fake.words(nb=num_words))
+ return phrase
+
+ def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]:
+ """Get global subtask names and temporal proportions for an annotation type."""
+ if annotation_type == "dense":
+ return self.dense_subtask_names, self.dense_temporal_proportions
+ return self.sparse_subtask_names, self.sparse_temporal_proportions
+
+ def _load_episode_annotations(
+ self,
+ ep_idx: int,
+ episodes_df: pd.DataFrame | None,
+ annotation_type: str,
+ global_names: list[str],
+ ) -> tuple[list | None, list | None, list | None]:
+ """Load subtask annotations for an episode from DataFrame."""
+ # Single-stage mode: (linear progress 0→1)
+ if episodes_df is None or len(global_names) == 1:
+ return None, None, None
+
+ # Resolve column name with fallback
+ def col(suffix):
+ prefixed = f"{annotation_type}_{suffix}"
+ return prefixed if prefixed in episodes_df.columns else suffix
+
+ col_names = col("subtask_names")
+ if col_names not in episodes_df.columns or ep_idx >= len(episodes_df):
+ return None, None, None
+
+ subtask_names = episodes_df.loc[ep_idx, col_names]
+ if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
+ return None, None, None
+
+ return (
+ subtask_names,
+ episodes_df.loc[ep_idx, col("subtask_start_frames")],
+ episodes_df.loc[ep_idx, col("subtask_end_frames")],
+ )
+
+ def __call__(self, transition: EnvTransition) -> EnvTransition:
+ """
+ Encode images, text, and normalize states in the transition.
+
+ Implements SARM training data preparation:
+ - Applies language perturbation (20% probability)
+ - Applies rewind augmentation (80% probability)
+ - Generates stage+tau targets for all frames
+ - Outputs lengths tensor for valid sequence masking
+ """
+ new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition)
+ observation = new_transition.get(TransitionKey.OBSERVATION)
+ comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
+
+ frame_index = comp_data.get("index")
+ episode_index = comp_data.get("episode_index")
+
+ if frame_index is None:
+ raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
+ if episode_index is None:
+ raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")
+
+ frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
+ episode_indices = self._get_episode_indices(frame_indices, episode_index)
+
+ image = observation.get(self.image_key)
+ if isinstance(image, torch.Tensor):
+ image = image.cpu().numpy()
+
+ # If 4D (T, C, H, W) from delta_timestamps, add batch dim
+ # If 3D (C, H, W) single frame, add batch and time dims
+ if image.ndim == 4:
+ image = image[np.newaxis, ...] # (T, C, H, W) -> (1, T, C, H, W)
+ elif image.ndim == 3:
+ image = image[np.newaxis, np.newaxis, ...] # (C, H, W) -> (1, 1, C, H, W)
+
+ batch_size = image.shape[0]
+ total_frames = image.shape[1] # (1 + n_obs_steps) obs frames + max_rewind_steps placeholders
+ n_obs_steps = self.config.n_obs_steps
+ max_rewind_steps = self.config.max_rewind_steps
+ n_obs_frames = 1 + n_obs_steps # observation frames, including the current frame
+
+ # Rewind augmentation
+ rewind_steps = torch.zeros(batch_size, dtype=torch.int32)
+ apply_rewind = self.training and random.random() < self.config.rewind_probability
+
+ if apply_rewind and self.dataset_meta is not None:
+ for b_idx, (ep_idx, frame_idx) in enumerate(
+ zip(episode_indices.tolist(), frame_indices.tolist(), strict=True)
+ ):
+ ep_idx, frame_idx = int(ep_idx), int(frame_idx)
+ ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
+
+ rewind_step, _ = apply_rewind_augmentation(
+ frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap
+ )
+ rewind_steps[b_idx] = rewind_step
+
+ # Compute valid lengths: n_obs_frames + rewind_steps
+ lengths = n_obs_frames + rewind_steps # (B,)
+
+ # Apply rewind masking to images:
+ # zero out frames beyond the valid sequence length
+ for b_idx in range(batch_size):
+ valid_len = lengths[b_idx].item()
+ if valid_len < total_frames:
+ image[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
+
+ # Encode images with CLIP
+ video_features = self._encode_images_batch(image)
+ observation["video_features"] = video_features
+
+ state_key = self.config.state_key
+ state_data = observation.get(state_key)
+
+ if isinstance(state_data, torch.Tensor):
+ state_tensor = state_data.float()
+ else:
+ state_tensor = torch.tensor(state_data, dtype=torch.float32)
+
+ if state_tensor.ndim == 2:
+ state_tensor = state_tensor.unsqueeze(0) # (T, D) -> (1, T, D)
+ elif state_tensor.ndim == 1:
+ state_tensor = state_tensor.unsqueeze(0).unsqueeze(0) # (D,) -> (1, 1, D)
+
+ # Apply same rewind masking to state
+ for b_idx in range(batch_size):
+ valid_len = lengths[b_idx].item()
+ if valid_len < state_tensor.shape[1]:
+ state_tensor[b_idx, valid_len:] = 0 # Zero out frames beyond valid length
+
+ observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
+
+ task = comp_data.get("task")
+ if isinstance(task, list):
+ task = task[0] if task else ""
+
+ # Apply language perturbation during training (20% probability)
+ # When perturbed, targets will be zeroed to train model to output low values for irrelevant text
+ apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability
+ if apply_perturbation:
+ task = self._generate_perturbed_task()
+
+ # Encode text with CLIP
+ observation["text_features"] = self._encode_text_clip(task, batch_size)
+
+ # Store lengths for model
+ observation["lengths"] = lengths
+
+ # Generate stage+tau targets; zeroed when language was perturbed so those samples don't contribute to the progress loss
+ if self.dataset_meta is not None:
+ episodes_df = None
+ if self.sparse_subtask_names != ["task"]:
+ episodes_df = self.dataset_meta.episodes.to_pandas()
+
+ # Generate sparse targets
+ if self.sparse_temporal_proportions is not None:
+ if apply_perturbation:
+ # Zero targets when language is perturbed
+ sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
+ else:
+ sparse_targets = self._compute_batch_targets(
+ frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse"
+ )
+ observation["sparse_targets"] = sparse_targets
+
+ # Generate dense targets (for dual mode)
+ if self.config.uses_dual_heads and self.dense_temporal_proportions is not None:
+ if apply_perturbation:
+ # Zero targets when language is perturbed
+ dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
+ else:
+ dense_targets = self._compute_batch_targets(
+ frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense"
+ )
+ observation["dense_targets"] = dense_targets
+
+ new_transition[TransitionKey.OBSERVATION] = observation
+ return new_transition
+
+ def _compute_batch_targets(
+ self,
+ frame_indices: np.ndarray,
+ episode_indices: np.ndarray,
+ lengths: torch.Tensor,
+ rewind_steps: torch.Tensor,
+ episodes_df: pd.DataFrame | None,
+ annotation_type: str,
+ ) -> torch.Tensor:
+ """Compute stage+tau targets for a batch of samples."""
+ batch_size = len(frame_indices)
+ n_obs_steps = self.config.n_obs_steps
+ max_rewind_steps = self.config.max_rewind_steps
+ total_frames = 1 + n_obs_steps + max_rewind_steps
+ frame_gap = self.config.frame_gap
+
+ global_names, temporal_props = self._get_annotation_config(annotation_type)
+ targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
+
+ for b_idx in range(batch_size):
+ ep_idx = int(episode_indices[b_idx])
+ frame_idx = int(frame_indices[b_idx])
+
+ ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
+ ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
+ ep_length = ep_end - ep_start
+
+ subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations(
+ ep_idx, episodes_df, annotation_type, global_names
+ )
+
+ # Compute observation frame indices
+ obs_indices, _ = compute_absolute_indices(
+ frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap
+ )
+ obs_indices = obs_indices.tolist()
+
+ # Compute targets for observation frames
+ for t_idx, abs_idx in enumerate(obs_indices):
+ rel_frame = abs_idx - ep_start
+ targets[b_idx, t_idx] = find_stage_and_tau(
+ rel_frame,
+ ep_length,
+ subtask_names,
+ subtask_start_frames,
+ subtask_end_frames,
+ global_names,
+ temporal_props,
+ return_combined=True,
+ )
+
+ # Compute targets for rewind frames (if any)
+ rewind_step = rewind_steps[b_idx].item()
+ if rewind_step > 0:
+ _, rewind_indices = apply_rewind_augmentation(
+ frame_idx,
+ ep_start,
+ n_obs_steps,
+ max_rewind_steps,
+ frame_gap=frame_gap,
+ rewind_step=rewind_step,
+ )
+
+ for r_idx, abs_idx in enumerate(rewind_indices[:rewind_step]):
+ rel_frame = max(0, abs_idx - ep_start)
+ targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau(
+ rel_frame,
+ ep_length,
+ subtask_names,
+ subtask_start_frames,
+ subtask_end_frames,
+ global_names,
+ temporal_props,
+ return_combined=True,
+ )
+
+ return targets
+
+ @property
+ def training(self) -> bool:
+ return getattr(self, "_training_mode", True)
+
+ def train(self, mode: bool = True):
+ """Set training mode for augmentation decisions."""
+ self._training_mode = mode
+ return self
+
+ def eval(self):
+ """Set evaluation mode (disable augmentations)."""
+ return self.train(False)
+
+ @torch.no_grad()
+ def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
+ """Encode a batch of images using CLIP.
+
+ Args:
+ images: Batched images with shape: (B, T, C, H, W)
+
+ Returns:
+ Encoded feature vectors with shape (B, T, 512)
+ """
+
+ batch_size, seq_length = images.shape[0], images.shape[1]
+ images = images.reshape(batch_size * seq_length, *images.shape[2:])
+
+ num_frames = images.shape[0]
+ images_list = []
+ for i in range(num_frames):
+ img = images[i]
+ if img.shape[0] in [1, 3]: # Channel first (C, H, W)
+ img = img.transpose(1, 2, 0)
+
+ # Handle single channel
+ if img.shape[-1] == 1:
+ img = np.repeat(img, 3, axis=-1)
+
+ if img.dtype != np.uint8:
+ img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
+
+ images_list.append(Image.fromarray(img))
+
+ all_embeddings = []
+ for i in range(0, num_frames, self.config.clip_batch_size):
+ batch_imgs = images_list[i : i + self.config.clip_batch_size]
+
+ inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+ # Get image embeddings
+ embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
+
+ # Handle single frame case
+ if embeddings.dim() == 1:
+ embeddings = embeddings.unsqueeze(0)
+
+ all_embeddings.append(embeddings)
+
+ all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
+ all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
+
+ return all_embeddings
+
+ @torch.no_grad()
+ def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
+ """Encode text using CLIP text encoder (per SARM paper A.4).
+
+ Args:
+ text: Task description text to encode
+ batch_size: Batch size to replicate for
+
+ Returns:
+ Encoded text features with shape (B, 512)
+ """
+ inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+ text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
+ text_embedding = text_embedding.expand(batch_size, -1)
+
+ return text_embedding
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ """Add encoded features to the observation features."""
+ features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature(
+ type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim)
+ )
+ features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature(
+ type=FeatureType.LANGUAGE, shape=(self.config.text_dim,)
+ )
+ features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature(
+ type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim)
+ )
+ return features
+
+
+def make_sarm_pre_post_processors(
+ config: SARMConfig,
+ dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+ dataset_meta=None,
+) -> tuple[
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+ PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+ """Create pre-processor and post-processor pipelines for SARM."""
+ return (
+ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+ steps=[
+ AddBatchDimensionProcessorStep(),
+ RenameObservationsProcessorStep(rename_map={}),
+ NormalizerProcessorStep(
+ features={**config.input_features, **config.output_features},
+ norm_map=config.normalization_mapping,
+ stats=dataset_stats,
+ ),
+ SARMEncodingProcessorStep(
+ config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats
+ ),
+ DeviceProcessorStep(device=config.device),
+ ],
+ name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+ ),
+ PolicyProcessorPipeline[PolicyAction, PolicyAction](
+ steps=[DeviceProcessorStep(device="cpu")],
+ name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+ to_transition=policy_action_to_transition,
+ to_output=transition_to_policy_action,
+ ),
+ )
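
The temporal-proportion normalization used by the model (`normalize_stage_tau` in `sarm_utils`) maps a raw stage.tau value into [0, 1] by weighting each stage by its dataset-level duration share. A minimal sketch of that idea, with made-up proportions — the real implementation may differ in details:

```python
def normalize_stage_tau_sketch(raw: float, proportions: list[float]) -> float:
    """Map a raw stage.tau value to [0, 1] using per-subtask temporal
    proportions. Illustrative only: assumes proportions sum to 1."""
    stage = int(raw)   # integer part = stage index
    tau = raw - stage  # fractional part = within-stage progress
    # Progress = total share of completed stages + partial share of current stage
    return sum(proportions[:stage]) + tau * proportions[stage]

# Three hypothetical subtasks taking 20%, 50%, and 30% of an episode:
props = [0.2, 0.5, 0.3]
print(normalize_stage_tau_sketch(1.5, props))  # 0.2 + 0.5 * 0.5 = 0.45
```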
diff --git a/src/lerobot/policies/sarm/sarm_utils.py b/src/lerobot/policies/sarm/sarm_utils.py
new file mode 100644
index 000000000..5b6955d38
--- /dev/null
+++ b/src/lerobot/policies/sarm/sarm_utils.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+
+import numpy as np
+import torch
+import torch.nn.functional as F # noqa: N812
+
+
+def find_stage_and_tau(
+ current_frame: int,
+ episode_length: int,
+ subtask_names: list | None,
+ subtask_start_frames: list | None,
+ subtask_end_frames: list | None,
+ global_subtask_names: list,
+ temporal_proportions: dict,
+ return_combined: bool = False,
+) -> tuple[int, float] | float:
+ """Find stage and within-stage progress (tau) for a frame.
+
+ Args:
+ current_frame: Frame index relative to episode start
+ episode_length: Total frames in episode
+ subtask_names: Subtask names for this episode (None for single_stage)
+ subtask_start_frames: Subtask start frames
+ subtask_end_frames: Subtask end frames
+ global_subtask_names: Global list of all subtask names
+ temporal_proportions: Dict of temporal proportions
+ return_combined: If True, return stage+tau as float; else (stage_idx, tau) tuple
+
+ Returns:
+ Float (stage.tau) if return_combined, else (stage_idx, tau) tuple
+ """
+ stage_idx, tau = 0, 0.0
+ num_stages = len(global_subtask_names)
+
+ # Single-stage mode: linear progress from 0 to 1
+ if num_stages == 1:
+ tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1)))
+ elif subtask_names is None:
+ pass # stage_idx=0, tau=0.0
+ elif current_frame < subtask_start_frames[0]:
+ pass # Before first subtask: stage_idx=0, tau=0.0
+ elif current_frame > subtask_end_frames[-1]:
+ stage_idx, tau = num_stages - 1, 0.999 # After last subtask
+ else:
+ # Find which subtask this frame belongs to
+ found = False
+ for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True):
+ if start <= current_frame <= end:
+ stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0
+ tau = compute_tau(current_frame, start, end)
+ found = True
+ break
+ # Frame between subtasks - use previous subtask's end state
+ if not found:
+ for j in range(len(subtask_names) - 1):
+ if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]:
+ name = subtask_names[j]
+ stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j
+ tau = 1.0
+ break
+
+ if return_combined:
+ # Clamp to avoid overflow at end
+ if stage_idx >= num_stages - 1 and tau >= 1.0:
+ return num_stages - 1 + 0.999
+ return stage_idx + tau
+ return stage_idx, tau
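The stage/tau labeling above can be sketched standalone; the helpers below are illustrative re-implementations for this note, not the module's API:

```python
# Minimal sketch of per-frame stage/progress labeling, mirroring
# find_stage_and_tau above (re-implemented here for illustration).

def tau_in_subtask(frame: int, start: int, end: int) -> float:
    """Within-stage progress tau = (t - s_k) / (e_k - s_k), clipped to [0, 1]."""
    duration = end - start
    if duration <= 0:
        return 1.0
    return min(1.0, max(0.0, (frame - start) / duration))

def label_frame(frame: int, subtasks: list[tuple[str, int, int]]) -> tuple[int, float]:
    """Return (stage_idx, tau) for a frame given ordered (name, start, end) subtasks."""
    for stage_idx, (_name, start, end) in enumerate(subtasks):
        if start <= frame <= end:
            return stage_idx, tau_in_subtask(frame, start, end)
    # Gap between consecutive subtasks: carry the previous subtask's end state.
    for stage_idx in range(len(subtasks) - 1):
        if subtasks[stage_idx][2] < frame < subtasks[stage_idx + 1][1]:
            return stage_idx, 1.0
    return 0, 0.0

subtasks = [("grasp", 0, 100), ("lift", 120, 200)]
assert label_frame(50, subtasks) == (0, 0.5)   # halfway through "grasp"
assert label_frame(110, subtasks) == (0, 1.0)  # gap: previous subtask completed
```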
+
+
+def compute_absolute_indices(
+ frame_idx: int,
+ ep_start: int,
+ ep_end: int,
+ n_obs_steps: int,
+ frame_gap: int = 30,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Compute absolute frame indices with clamping for bidirectional observation sequence.
+
+ Bidirectional sampling centered on target frame:
+ - Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames)
+ - Current: [0] (1 frame)
+ - After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames)
+ - Total: n_obs_steps + 1 frames
+
+ Out-of-bounds frames are clamped (duplicated from boundary).
+
+ Args:
+ frame_idx: Target frame index (center frame of sequence)
+ ep_start: Episode start index
+ ep_end: Episode end index (exclusive)
+ n_obs_steps: Number of observation steps (must be even for symmetric sampling)
+ frame_gap: Gap between observation frames
+
+ Returns:
+ Tuple of (indices, out_of_bounds_flags)
+ """
+ half_steps = n_obs_steps // 2
+
+ # Bidirectional deltas: past + current + future
+ past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)]
+ future_deltas = [frame_gap * i for i in range(1, half_steps + 1)]
+ delta_indices = past_deltas + [0] + future_deltas
+
+ frames = []
+ out_of_bounds = []
+
+ for delta in delta_indices:
+ target_idx = frame_idx + delta
+ # Clamp to episode bounds (duplicate boundary frames for out-of-bounds)
+ clamped_idx = max(ep_start, min(ep_end - 1, target_idx))
+ frames.append(clamped_idx)
+ # Flag as out-of-bounds if clamping occurred
+ out_of_bounds.append(1 if target_idx != clamped_idx else 0)
+
+ return torch.tensor(frames), torch.tensor(out_of_bounds)
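For reference, the bidirectional sampling above reduces to the following standalone sketch (pure Python, without the torch tensor wrapping):

```python
# Bidirectional index sampling with boundary clamping, mirroring
# compute_absolute_indices above (returns plain lists instead of tensors).

def bidirectional_indices(
    frame_idx: int, ep_start: int, ep_end: int, n_obs_steps: int, frame_gap: int = 30
) -> tuple[list[int], list[int]]:
    half = n_obs_steps // 2
    deltas = (
        [-frame_gap * i for i in range(half, 0, -1)]   # past frames
        + [0]                                          # current frame
        + [frame_gap * i for i in range(1, half + 1)]  # future frames
    )
    indices, out_of_bounds = [], []
    for delta in deltas:
        target = frame_idx + delta
        clamped = max(ep_start, min(ep_end - 1, target))  # duplicate boundary frames
        indices.append(clamped)
        out_of_bounds.append(int(clamped != target))
    return indices, out_of_bounds

# Early in an episode, past frames clamp to the episode start:
idx, oob = bidirectional_indices(frame_idx=10, ep_start=0, ep_end=300, n_obs_steps=4)
assert idx == [0, 0, 10, 40, 70] and oob == [1, 1, 0, 0, 0]
```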
+
+
+def apply_rewind_augmentation(
+ frame_idx: int,
+ ep_start: int,
+ n_obs_steps: int,
+ max_rewind_steps: int,
+ frame_gap: int = 30,
+ rewind_step: int | None = None,
+) -> tuple[int, list[int]]:
+ """
+ Generate rewind frame indices for temporal augmentation.
+
+ Rewind simulates going backwards through previously seen frames,
+ starting from before the earliest observation frame (for bidirectional sampling).
+ Appends reversed frames after the observation sequence.
+
+ Args:
+ frame_idx: Target frame index (center of bidirectional observation window)
+ ep_start: Episode start index
+ n_obs_steps: Number of observation steps
+ max_rewind_steps: Maximum rewind steps
+ frame_gap: Gap between frames
+ rewind_step: If provided, use this exact rewind step (for deterministic behavior).
+ If None, sample randomly.
+
+ Returns:
+ Tuple of (rewind_step, rewind_indices)
+ """
+ # For bidirectional sampling, earliest obs frame is at frame_idx - half_steps * frame_gap
+ half_steps = n_obs_steps // 2
+ earliest_obs_frame = frame_idx - half_steps * frame_gap
+
+ # Required history: frames before earliest observation frame
+ if earliest_obs_frame <= ep_start:
+ return 0, [] # No history before observation window
+
+ # Max valid rewind steps based on available history before earliest obs frame
+ available_history = earliest_obs_frame - ep_start
+ max_valid_step = available_history // frame_gap
+ max_rewind = min(max_rewind_steps, max(0, max_valid_step))
+
+ if max_rewind <= 0:
+ return 0, []
+
+ # Sample rewind steps if not provided
+ rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind)
+
+ if rewind_step == 0:
+ return 0, []
+
+ # Generate rewind indices going backwards from earliest obs frame
+ # rewind_indices[0] is closest to obs window, rewind_indices[-1] is furthest back
+ rewind_indices = []
+ for i in range(1, rewind_step + 1):
+ idx = earliest_obs_frame - i * frame_gap
+ idx = max(ep_start, idx) # Clamp to episode start
+ rewind_indices.append(idx)
+
+ return rewind_step, rewind_indices
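The deterministic path of the rewind augmentation (fixed `rewind_step`, no random sampling) can be sketched as follows; this is an illustrative simplification, not the function above:

```python
# Deterministic sketch of the rewind augmentation above: walk backwards from
# the earliest observation frame in fixed frame_gap steps, clamped to ep_start.

def rewind_indices(
    frame_idx: int, ep_start: int, n_obs_steps: int, rewind_step: int, frame_gap: int = 30
) -> list[int]:
    earliest = frame_idx - (n_obs_steps // 2) * frame_gap
    if earliest <= ep_start:
        return []  # no history before the observation window
    max_valid = (earliest - ep_start) // frame_gap
    step = min(rewind_step, max_valid)
    return [max(ep_start, earliest - i * frame_gap) for i in range(1, step + 1)]

# Frame 200 with 8 obs steps -> earliest obs frame at 200 - 4*30 = 80;
# only 2 full frame_gap steps of history fit before it.
assert rewind_indices(200, 0, 8, rewind_step=3) == [50, 20]
assert rewind_indices(60, 0, 8, rewind_step=3) == []
```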
+
+
+def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float:
+ """Compute τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]. Returns 1.0 for zero-duration subtasks."""
+ duration = subtask_end - subtask_start
+ if duration <= 0:
+ return 1.0
+ return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))
+
+
+def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
+ """Pad the state tensor's last dimension to max_state_dim with zeros."""
+ current_dim = state.shape[-1]
+ if current_dim >= max_state_dim:
+ return state[..., :max_state_dim] # Truncate if larger
+
+ # Pad with zeros on the right
+ padding = (0, max_state_dim - current_dim) # (left, right) for last dim
+ return F.pad(state, padding, mode="constant", value=0)
+
+
+def temporal_proportions_to_breakpoints(
+ temporal_proportions: dict[str, float] | list[float] | None,
+ subtask_names: list[str] | None = None,
+) -> list[float] | None:
+ """Convert temporal proportions to cumulative breakpoints for normalization."""
+ if temporal_proportions is None:
+ return None
+
+ if isinstance(temporal_proportions, dict):
+ if subtask_names is not None:
+ proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names]
+ else:
+ proportions = list(temporal_proportions.values())
+ else:
+ proportions = list(temporal_proportions)
+
+ total = sum(proportions)
+ if total > 0 and abs(total - 1.0) > 1e-6:
+ proportions = [p / total for p in proportions]
+
+ breakpoints = [0.0]
+ cumsum = 0.0
+ for prop in proportions:
+ cumsum += prop
+ breakpoints.append(cumsum)
+ breakpoints[-1] = 1.0
+
+ return breakpoints
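The proportions-to-breakpoints conversion above boils down to a cumulative sum with renormalization, sketched here standalone:

```python
# Sketch of converting per-subtask temporal proportions into cumulative
# breakpoints, as temporal_proportions_to_breakpoints does above.
# Proportions that do not sum to 1 are renormalized first.

def to_breakpoints(proportions: list[float]) -> list[float]:
    total = sum(proportions)
    if total > 0 and abs(total - 1.0) > 1e-6:
        proportions = [p / total for p in proportions]
    breakpoints, cumsum = [0.0], 0.0
    for prop in proportions:
        cumsum += prop
        breakpoints.append(cumsum)
    breakpoints[-1] = 1.0  # guard against floating-point drift
    return breakpoints

assert to_breakpoints([0.25, 0.25, 0.5]) == [0.0, 0.25, 0.5, 1.0]
assert to_breakpoints([1, 1, 2]) == [0.0, 0.25, 0.5, 1.0]  # renormalized
```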
+
+
+def normalize_stage_tau(
+ x: float | torch.Tensor,
+ num_stages: int | None = None,
+ breakpoints: list[float] | None = None,
+ temporal_proportions: dict[str, float] | list[float] | None = None,
+ subtask_names: list[str] | None = None,
+) -> float | torch.Tensor:
+ """
+ Normalize stage+tau reward to [0, 1] with custom breakpoints.
+
+ Maps stage index + within-stage tau to normalized progress [0, 1].
+ The breakpoints are designed to give appropriate weight to each stage
+ based on their importance in the task (using temporal proportions).
+
+ Priority: breakpoints > temporal_proportions > linear fallback
+
+ Args:
+ x: Raw reward value (stage index + tau) where stage ∈ [0, num_stages-1] and tau ∈ [0, 1)
+ num_stages: Number of stages (required if breakpoints/proportions not provided)
+ breakpoints: Optional custom breakpoints list of length num_stages + 1.
+ temporal_proportions: Optional temporal proportions dict/list to compute breakpoints.
+ subtask_names: Optional ordered list of subtask names (for dict proportions)
+
+ Returns:
+ Normalized progress value ∈ [0, 1]
+ """
+ if breakpoints is not None:
+ num_stages = len(breakpoints) - 1
+ elif temporal_proportions is not None:
+ breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
+ num_stages = len(breakpoints) - 1
+ elif num_stages is not None:
+ breakpoints = [i / num_stages for i in range(num_stages + 1)]
+ else:
+ raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided")
+
+ if isinstance(x, torch.Tensor):
+ result = torch.zeros_like(x)
+ for i in range(num_stages):
+ mask = (x >= i) & (x < i + 1)
+ tau_in_stage = x - i
+ result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i])
+ result[x >= num_stages] = 1.0
+ return result.clamp(0.0, 1.0)
+ else:
+ if x < 0:
+ return 0.0
+ if x >= num_stages:
+ return 1.0
+ stage = int(x)
+ tau = x - stage
+ return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])
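To see how the scalar branch of `normalize_stage_tau` behaves, here is a stripped-down sketch (breakpoints assumed precomputed; illustrative only):

```python
# Scalar sketch of normalize_stage_tau above: stage index + within-stage tau
# is mapped piecewise-linearly onto [0, 1] via per-stage breakpoints.

def normalize(x: float, breakpoints: list[float]) -> float:
    num_stages = len(breakpoints) - 1
    if x < 0:
        return 0.0
    if x >= num_stages:
        return 1.0
    stage = int(x)
    tau = x - stage
    return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])

# Two stages where the first covers 25% of task time: halfway through
# stage 1 lands at 0.25 + 0.5 * 0.75 = 0.625.
assert normalize(1.5, [0.0, 0.25, 1.0]) == 0.625
assert normalize(2.7, [0.0, 0.25, 1.0]) == 1.0  # clamped past the last stage
```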
diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py
index 485d3e4e5..f998661f9 100644
--- a/src/lerobot/policies/smolvla/modeling_smolvla.py
+++ b/src/lerobot/policies/smolvla/modeling_smolvla.py
@@ -231,6 +231,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
def __init__(
self,
config: SmolVLAConfig,
+ **kwargs,
):
"""
Args:
@@ -352,8 +353,19 @@ class SmolVLAPolicy(PreTrainedPolicy):
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
- def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
- """Do a full training forward pass to compute the loss"""
+ def forward(
+ self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean"
+ ) -> dict[str, Tensor]:
+ """Do a full training forward pass to compute the loss.
+
+ Args:
+ batch: Training batch containing observations and actions.
+ noise: Optional noise tensor for flow matching.
+ time: Optional time tensor for flow matching.
+ reduction: How to reduce the loss. Options:
+ - "mean": Return scalar mean loss (default, backward compatible)
+ - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+ """
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
@@ -377,11 +389,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
- # For backward pass
- loss = losses.mean()
- # For backward pass
- loss_dict["loss"] = loss.item()
- return loss, loss_dict
+ if reduction == "none":
+ # Return per-sample losses (B,) by averaging over time and action dims
+ per_sample_loss = losses.mean(dim=(1, 2))
+ loss_dict["loss"] = per_sample_loss.mean().item()
+ return per_sample_loss, loss_dict
+ else:
+ # Default: return scalar mean loss
+ loss = losses.mean()
+ loss_dict["loss"] = loss.item()
+ return loss, loss_dict
def prepare_images(self, batch):
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
index 195cf6154..f83c82e21 100644
--- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py
+++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
@@ -65,6 +65,7 @@ class TDMPCPolicy(PreTrainedPolicy):
def __init__(
self,
config: TDMPCConfig,
+ **kwargs,
):
"""
Args:
diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py
index c4ca35b72..bfbe2bf1d 100644
--- a/src/lerobot/policies/utils.py
+++ b/src/lerobot/policies/utils.py
@@ -231,11 +231,20 @@ def validate_visual_features_consistency(
"""
Validates visual feature consistency between a policy config and provided dataset/environment features.
+ Validation passes if EITHER:
+ - Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
+ - Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
+
Args:
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
"""
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
- if not provided_visuals.issubset(expected_visuals):
+
+ # Accept if either direction is a subset
+ policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
+ dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
+
+ if not (policy_subset_of_dataset or dataset_subset_of_policy):
raise_feature_mismatch_error(provided_visuals, expected_visuals)
diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py
index 91d609701..359b4fdb1 100644
--- a/src/lerobot/policies/vqbet/modeling_vqbet.py
+++ b/src/lerobot/policies/vqbet/modeling_vqbet.py
@@ -47,6 +47,7 @@ class VQBeTPolicy(PreTrainedPolicy):
def __init__(
self,
config: VQBeTConfig | None = None,
+ **kwargs,
):
"""
Args:
diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py
index 27c7c6e1b..0436ae527 100644
--- a/src/lerobot/policies/xvla/modeling_xvla.py
+++ b/src/lerobot/policies/xvla/modeling_xvla.py
@@ -273,7 +273,7 @@ class XVLAPolicy(PreTrainedPolicy):
config_class = XVLAConfig
name = "xvla"
- def __init__(self, config: XVLAConfig):
+ def __init__(self, config: XVLAConfig, **kwargs):
super().__init__(config)
config.validate_features()
florence_config = config.get_florence_config()
diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py
index 6b0b67598..126be0e36 100644
--- a/src/lerobot/processor/converters.py
+++ b/src/lerobot/processor/converters.py
@@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
+ episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
- return {**pad_keys, **task_key, **index_key, **task_index_key}
+ return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 1ebdee600..6cf733442 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -62,6 +62,7 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
+ rabc_weights_provider=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -78,6 +79,7 @@ def update_policy(
accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler.
lock: An optional lock for thread-safe optimizer updates.
+ rabc_weights_provider: Optional RABCWeights instance for sample weighting.
Returns:
A tuple containing:
@@ -87,9 +89,30 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
+ # Get RA-BC weights if enabled
+ rabc_batch_weights = None
+ rabc_batch_stats = None
+ if rabc_weights_provider is not None:
+ rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
+
# Let accelerator handle mixed precision
with accelerator.autocast():
- loss, output_dict = policy.forward(batch)
+ # Use per-sample loss when RA-BC is enabled for proper weighting
+ if rabc_batch_weights is not None:
+ # Get per-sample losses
+ per_sample_loss, output_dict = policy.forward(batch, reduction="none")
+
+ # Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
+ # rabc_batch_weights is already normalized to sum to batch_size
+ epsilon = 1e-6
+ loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
+ # Log raw mean weight (before normalization) - this is the meaningful metric
+ output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
+ output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
+ output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
+ else:
+ loss, output_dict = policy.forward(batch)
+
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -141,8 +164,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
- cfg.validate()
-
# Create Accelerator if not provided
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -159,6 +180,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# When using accelerate, only the main process should log to avoid duplicate outputs
is_main_process = accelerator.is_main_process
+ cfg.validate()
+
# Only log on main process
if is_main_process:
logging.info(pformat(cfg.to_dict()))
@@ -217,6 +240,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
+ # For SARM, always provide dataset_meta for progress normalization
+ if cfg.policy.type == "sarm":
+ processor_kwargs["dataset_meta"] = dataset.meta
+
if cfg.policy.pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
@@ -248,6 +275,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
+ # Load precomputed SARM progress for RA-BC if enabled
+ # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
+ rabc_weights = None
+ if cfg.use_rabc:
+ from lerobot.utils.rabc import RABCWeights
+
+ # Get chunk_size from policy config
+ chunk_size = getattr(policy.config, "chunk_size", None)
+ if chunk_size is None:
+ raise ValueError("chunk_size not found in policy config; RA-BC requires a policy with chunk_size")
+
+ head_mode = getattr(cfg, "rabc_head_mode", "sparse")
+ logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
+ logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
+ rabc_weights = RABCWeights(
+ progress_path=cfg.rabc_progress_path,
+ chunk_size=chunk_size,
+ head_mode=head_mode,
+ kappa=getattr(cfg, "rabc_kappa", 0.01),
+ epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
+ device=device,
+ )
+
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
@@ -327,7 +377,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
- logging.info("Start offline training on a fixed dataset")
+ logging.info(
+ f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
+ )
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
@@ -343,6 +395,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
+ rabc_weights_provider=rabc_weights,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -359,6 +412,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
+ # Log RA-BC statistics if enabled
+ if rabc_weights is not None:
+ rabc_stats = rabc_weights.get_stats()
+ wandb_log_dict.update(
+ {
+ "rabc_delta_mean": rabc_stats["delta_mean"],
+ "rabc_delta_std": rabc_stats["delta_std"],
+ "rabc_num_frames": rabc_stats["num_frames"],
+ }
+ )
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/utils/rabc.py
new file mode 100644
index 000000000..c529f3ccc
--- /dev/null
+++ b/src/lerobot/utils/rabc.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+
+
+class RABCWeights:
+ """
+ Load precomputed SARM progress values and compute RA-BC weights during training.
+
+ Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
+ During training, computes:
+ - progress_delta = progress[t + chunk_size] - progress[t]
+ - rabc_weight based on the delta (paper Eq. 8-9)
+
+ Args:
+ progress_path: Path to parquet file with precomputed progress values
+ chunk_size: Number of frames ahead for computing progress delta
+ head_mode: Which SARM head to use ("sparse" or "dense")
+ kappa: Hard threshold for high-quality samples (default: 0.01)
+ epsilon: Small constant for numerical stability (default: 1e-6)
+ fallback_weight: Weight to use for frames without valid delta (default: 1.0)
+ device: Device to return tensors on
+ """
+
+ def __init__(
+ self,
+ progress_path: str | Path,
+ chunk_size: int = 50,
+ head_mode: str = "sparse",
+ kappa: float = 0.01,
+ epsilon: float = 1e-6,
+ fallback_weight: float = 1.0,
+ device: torch.device | None = None,
+ ):
+ self.progress_path = Path(progress_path)
+ self.chunk_size = chunk_size
+ self.head_mode = head_mode
+ self.kappa = kappa
+ self.epsilon = epsilon
+ self.fallback_weight = fallback_weight
+ self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Determine progress column name
+ self.progress_column = f"progress_{head_mode}"
+
+ # Load progress values
+ logging.info(f"Loading SARM progress values from {self.progress_path}")
+ self.df = pd.read_parquet(self.progress_path)
+
+ # Check if the requested head mode column exists
+ if self.progress_column not in self.df.columns:
+ available = [c for c in self.df.columns if c.startswith("progress")]
+ raise ValueError(
+ f"Column '{self.progress_column}' not found. Available progress columns: {available}"
+ )
+
+ logging.info(f"Using progress column: {self.progress_column}")
+
+ self.progress_lookup = {}
+ self.episode_lookup = {}
+
+ for _, row in self.df.iterrows():
+ global_idx = int(row["index"])
+ progress = row[self.progress_column]
+ episode_idx = int(row["episode_index"])
+
+ if not np.isnan(progress):
+ self.progress_lookup[global_idx] = float(progress)
+ self.episode_lookup[global_idx] = episode_idx
+
+ # Build episode boundaries for delta computation
+ self.episode_boundaries = {}
+ for episode_idx in self.df["episode_index"].unique():
+ ep_df = self.df[self.df["episode_index"] == episode_idx]
+ self.episode_boundaries[int(episode_idx)] = {
+ "start": int(ep_df["index"].min()),
+ "end": int(ep_df["index"].max()) + 1,
+ }
+
+ logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
+ logging.info(f"Chunk size for delta computation: {chunk_size}")
+
+ # Compute global statistics for weight computation
+ self._compute_global_stats()
+
+ def _compute_global_stats(self):
+ """Compute global mean and std of progress deltas for weight calculation."""
+ all_deltas = []
+
+ for global_idx, progress in self.progress_lookup.items():
+ episode_idx = self.episode_lookup.get(global_idx)
+ if episode_idx is None:
+ continue
+
+ bounds = self.episode_boundaries.get(episode_idx)
+ if bounds is None:
+ continue
+
+ future_idx = global_idx + self.chunk_size
+ if future_idx >= bounds["end"]:
+ # Near end of episode: use last frame's progress
+ future_idx = bounds["end"] - 1
+
+ future_progress = self.progress_lookup.get(future_idx)
+ if future_progress is not None:
+ delta = future_progress - progress
+ all_deltas.append(delta)
+
+ if all_deltas:
+ self.delta_mean = max(np.mean(all_deltas), 0.0)
+ self.delta_std = max(np.std(all_deltas), self.epsilon)
+ logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
+ else:
+ self.delta_mean = 0.0
+ self.delta_std = self.epsilon
+ logging.warning("No valid progress deltas found, using default stats")
+
+ def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
+ """
+ Compute RA-BC weights for a batch.
+
+ For each sample:
+ 1. Get progress at current frame
+ 2. Get progress at frame + chunk_size (within same episode)
+ 3. Compute delta = future_progress - current_progress
+ 4. Compute weight using paper Eq. 8-9
+
+ Args:
+ batch: Training batch containing "index" key with global frame indices
+
+ Returns:
+ Tuple of:
+ - Weights tensor (batch_size,) normalized to sum to batch_size
+ - Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
+ """
+ indices = batch.get("index")
+ if indices is None:
+ logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
+ batch_size = self._get_batch_size(batch)
+ return torch.ones(batch_size, device=self.device), {
+ "raw_mean_weight": 1.0,
+ "num_zero_weight": 0,
+ "num_full_weight": 0,
+ }
+
+ # Convert to list of ints
+ if isinstance(indices, torch.Tensor):
+ indices = indices.cpu().numpy().tolist()
+ elif isinstance(indices, np.ndarray):
+ indices = indices.tolist()
+
+ # Compute deltas and weights for each sample
+ deltas = []
+ for idx in indices:
+ idx = int(idx)
+ delta = self._compute_delta(idx)
+ deltas.append(delta)
+
+ deltas = np.array(deltas, dtype=np.float32)
+
+ # Compute weights from deltas
+ weights = self._compute_weights(deltas)
+
+ # Compute stats before normalization for logging
+ raw_mean_weight = float(np.nanmean(weights))
+ num_zero_weight = int(np.sum(weights == 0))
+ num_full_weight = int(np.sum(weights == 1.0))
+ batch_stats = {
+ "raw_mean_weight": raw_mean_weight,
+ "num_zero_weight": num_zero_weight,
+ "num_full_weight": num_full_weight,
+ }
+
+ weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
+
+ # Normalize to sum to batch_size
+ batch_size = len(weights)
+ weight_sum = weights.sum() + self.epsilon
+ weights = weights * batch_size / weight_sum
+
+ return weights, batch_stats
+
+ def _compute_delta(self, global_idx: int) -> float:
+ """Compute progress delta for a single frame."""
+ current_progress = self.progress_lookup.get(global_idx)
+ if current_progress is None:
+ return np.nan
+
+ episode_idx = self.episode_lookup.get(global_idx)
+ if episode_idx is None:
+ return np.nan
+
+ bounds = self.episode_boundaries.get(episode_idx)
+ if bounds is None:
+ return np.nan
+
+ future_idx = global_idx + self.chunk_size # Δ = chunk_size
+ if future_idx >= bounds["end"]:
+ # Near end of episode: use last frame's progress instead
+ future_idx = bounds["end"] - 1
+
+ future_progress = self.progress_lookup.get(future_idx)
+ if future_progress is None:
+ return np.nan
+
+ return future_progress - current_progress
+
+ def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
+ """
+ Compute RA-BC weights from progress deltas.
+
+ Following paper Eq. 8-9:
+ - Soft weight: w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)
+ - Final weight: w_i = 1{r_i > κ} + 1{0 ≤ r_i ≤ κ} · w̃_i
+
+ Returns:
+ Array of weights
+ """
+ valid_mask = ~np.isnan(deltas)
+
+ # Compute soft weights using global statistics
+ lower_bound = self.delta_mean - 2 * self.delta_std
+ soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
+ soft_weights = np.clip(soft_weights, 0.0, 1.0)
+
+ # Apply paper's Eq. 9
+ weights = np.zeros_like(deltas, dtype=np.float32)
+
+ # High quality: ri > kappa → weight = 1
+ high_quality_mask = deltas > self.kappa
+ weights[high_quality_mask] = 1.0
+
+ # Moderate quality: 0 <= ri <= kappa → weight = soft_weight
+ moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
+ weights[moderate_mask] = soft_weights[moderate_mask]
+
+ # Negative progress: ri < 0 → weight = 0 (already 0)
+ # Invalid (NaN): use fallback weight
+ weights[~valid_mask] = self.fallback_weight
+
+ return weights
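The Eq. 8-9 weighting can be reproduced as a small vectorized sketch (a re-implementation for illustration, not the class method above):

```python
import numpy as np

# Sketch of the Eq. 8-9 weighting implemented above: deltas above kappa get
# weight 1, deltas in [0, kappa] get a soft weight from global delta
# statistics (mu, sigma), negative deltas get 0, and NaN deltas fall back.

def rabc_weights(deltas, mu, sigma, kappa=0.01, eps=1e-6, fallback=1.0):
    deltas = np.asarray(deltas, dtype=np.float32)
    soft = np.clip((deltas - (mu - 2 * sigma)) / (4 * sigma + eps), 0.0, 1.0)
    weights = np.zeros_like(deltas)
    weights[deltas > kappa] = 1.0          # high-quality samples
    moderate = (deltas >= 0) & (deltas <= kappa)
    weights[moderate] = soft[moderate]     # soft-weighted samples
    weights[np.isnan(deltas)] = fallback   # frames without a valid delta
    return weights

w = rabc_weights([0.05, -0.01, float("nan")], mu=0.01, sigma=0.005)
assert w.tolist() == [1.0, 0.0, 1.0]
```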
+
+ def _get_batch_size(self, batch: dict) -> int:
+ """Determine batch size from batch."""
+ for key in ["action", "index"]:
+ if key in batch:
+ val = batch[key]
+ if isinstance(val, (torch.Tensor, np.ndarray)):
+ return val.shape[0]
+ return 1
+
+ def get_stats(self) -> dict:
+ """Return summary statistics of the loaded progress data for logging."""
+ return {
+ "num_frames": len(self.progress_lookup),
+ "chunk_size": self.chunk_size,
+ "head_mode": self.head_mode,
+ "delta_mean": self.delta_mean,
+ "delta_std": self.delta_std,
+ "kappa": self.kappa,
+ }
diff --git a/tests/policies/test_sarm_processor.py b/tests/policies/test_sarm_processor.py
new file mode 100644
index 000000000..66404f663
--- /dev/null
+++ b/tests/policies/test_sarm_processor.py
@@ -0,0 +1,694 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+pytest.importorskip("faker")
+
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pandas as pd
+import torch
+
+from lerobot.processor.core import TransitionKey
+
+
+class MockDatasetMeta:
+ """Mock dataset metadata for testing processor."""
+
+ def __init__(self, episodes: list[dict]):
+ self._episodes = episodes
+
+ @property
+ def episodes(self):
+ """Return episodes as a mock object with to_pandas() method."""
+ mock = MagicMock()
+ mock.__len__ = lambda s: len(self._episodes)
+ mock.__getitem__ = lambda s, idx: self._episodes[idx]
+ mock.to_pandas = lambda: pd.DataFrame(self._episodes)
+ return mock
+
+
+class MockConfig:
+ """Mock SARMConfig for testing processor methods."""
+
+ def __init__(
+ self,
+ n_obs_steps: int = 8,
+ max_rewind_steps: int = 4,
+ frame_gap: int = 30,
+ sparse_subtask_names: list | None = None,
+ sparse_temporal_proportions: list | None = None,
+ dense_subtask_names: list | None = None,
+ dense_temporal_proportions: list | None = None,
+ image_key: str = "observation.images.top",
+ state_key: str = "observation.state",
+ max_state_dim: int = 32,
+ device: str | None = None,
+ rewind_probability: float = 0.8,
+ language_perturbation_probability: float = 0.2,
+ annotation_mode: str = "dual",
+ clip_batch_size: int = 64,
+ text_dim: int = 512,
+ ):
+ self.n_obs_steps = n_obs_steps
+ self.max_rewind_steps = max_rewind_steps
+ self.frame_gap = frame_gap
+ self.sparse_subtask_names = sparse_subtask_names or ["task"]
+ self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0]
+ self.dense_subtask_names = dense_subtask_names
+ self.dense_temporal_proportions = dense_temporal_proportions
+ self.uses_dual_heads = annotation_mode in ["dense_only", "dual"]
+ self.image_key = image_key
+ self.state_key = state_key
+ self.max_state_dim = max_state_dim
+ self.device = device
+ self.rewind_probability = rewind_probability
+ self.language_perturbation_probability = language_perturbation_probability
+ self.annotation_mode = annotation_mode
+ self.clip_batch_size = clip_batch_size
+ self.text_dim = text_dim
+
+ # Compute observation delta indices (same as config: bidirectional)
+ half_steps = self.n_obs_steps // 2
+ past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
+ future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
+ obs_deltas = past_deltas + [0] + future_deltas
+ rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
+ self.observation_delta_indices = obs_deltas + rewind_deltas
+
+ @property
+ def num_frames(self) -> int:
+ return 1 + self.n_obs_steps + self.max_rewind_steps
+
+
+class TestSARMEncodingProcessorStepEndToEnd:
+ """End-to-end test for SARMEncodingProcessorStep with dummy batch data."""
+
+ @pytest.fixture
+ def mock_clip_model(self):
+ """Mock CLIP model to avoid loading real weights."""
+ with (
+ patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
+ patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
+ ):
+ # Mock the CLIP model - return embeddings based on input batch size
+ mock_model = MagicMock()
+
+ def get_image_features_side_effect(**kwargs):
+ pixel_values = kwargs.get("pixel_values")
+ batch_size = pixel_values.shape[0] if pixel_values is not None else 1
+ return torch.randn(batch_size, 512)
+
+ mock_model.get_image_features.side_effect = get_image_features_side_effect
+ mock_model.get_text_features.return_value = torch.randn(1, 512)
+ mock_model.to.return_value = mock_model
+ mock_model_cls.from_pretrained.return_value = mock_model
+
+ # Mock the CLIP processor - return tensors based on input images
+ mock_processor = MagicMock()
+
+ def processor_side_effect(images=None, **kwargs):
+ num_images = len(images) if images is not None else 1
+ return {
+ "pixel_values": torch.randn(num_images, 3, 224, 224),
+ }
+
+ mock_processor.side_effect = processor_side_effect
+ # Mock tokenizer for text encoding
+ mock_processor.tokenizer.return_value = {
+ "input_ids": torch.ones(1, 77, dtype=torch.long),
+ "attention_mask": torch.ones(1, 77, dtype=torch.long),
+ }
+ mock_processor_cls.from_pretrained.return_value = mock_processor
+
+ yield mock_model, mock_processor
+
+ @pytest.fixture
+ def processor_with_mocks(self, mock_clip_model):
+ """Create a processor with mocked CLIP and dataset metadata for dual mode."""
+ from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+
+ # Dual mode config with both sparse and dense annotations
+ config = MockConfig(
+ n_obs_steps=8,
+ max_rewind_steps=4,
+ frame_gap=30,
+ rewind_probability=0.0, # Disable for deterministic test
+ language_perturbation_probability=0.0, # Disable for deterministic test
+ annotation_mode="dual",
+ sparse_subtask_names=["reach", "grasp", "lift"],
+ sparse_temporal_proportions=[0.3, 0.4, 0.3],
+ dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"],
+ dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
+ )
+
+ # Create mock dataset metadata with one episode of 300 frames
+ # Include annotation columns for dual mode
+ episodes = [
+ {
+ "dataset_from_index": 0,
+ "dataset_to_index": 300,
+ "task": "pick up the cube",
+ "sparse_subtask_names": ["reach", "grasp", "lift"],
+ "sparse_subtask_start_frames": [0, 90, 210],
+ "sparse_subtask_end_frames": [90, 210, 300],
+ "dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"],
+ "dense_subtask_start_frames": [0, 75, 150, 225],
+ "dense_subtask_end_frames": [75, 150, 225, 300],
+ }
+ ]
+ dataset_meta = MockDatasetMeta(episodes)
+
+ processor = SARMEncodingProcessorStep(
+ config=config,
+ dataset_meta=dataset_meta,
+ )
+ processor.train(True) # Use train() method, not direct assignment
+
+ return processor, config
+
+ def test_call_with_single_frame_batch(self, processor_with_mocks):
+ """Test processor __call__ with a single-frame batch."""
+ processor, config = processor_with_mocks
+
+ # Create dummy input transition
+ batch_size = 1
+ num_frames = config.num_frames # 13 frames (9 obs + 4 rewind)
+
+ # Image: (T, C, H, W) format as expected by processor
+ dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
+
+ # State: (T, D) format
+ dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
+
+ transition = {
+ TransitionKey.OBSERVATION: {
+ config.image_key: dummy_image,
+ config.state_key: dummy_state,
+ },
+ TransitionKey.COMPLEMENTARY_DATA: {
+ "index": 150, # Middle of episode
+ "episode_index": 0,
+ "task": "pick up the cube",
+ },
+ }
+
+ # Run processor
+ result = processor(transition)
+
+ # Verify output structure
+ obs = result[TransitionKey.OBSERVATION]
+
+ # Check video features exist and have correct shape
+ assert "video_features" in obs
+ video_features = obs["video_features"]
+ assert video_features.shape[0] == batch_size
+ assert video_features.shape[1] == num_frames
+ assert video_features.shape[2] == 512 # CLIP embedding dim
+
+ # Check state features exist and have correct shape
+ assert "state_features" in obs
+ state_features = obs["state_features"]
+ assert state_features.shape[0] == batch_size
+ assert state_features.shape[1] == num_frames
+ assert state_features.shape[2] == config.max_state_dim # Padded to max_state_dim
+
+ # Check text features exist and have correct shape
+ assert "text_features" in obs
+ text_features = obs["text_features"]
+ assert text_features.shape[0] == batch_size
+ assert text_features.shape[1] == 512 # CLIP embedding dim
+
+ # Check lengths tensor
+ assert "lengths" in obs
+ lengths = obs["lengths"]
+ assert lengths.shape[0] == batch_size
+ assert lengths.dtype == torch.int32
+
+ # Check sparse_targets exist
+ assert "sparse_targets" in obs
+ sparse_targets = obs["sparse_targets"]
+ assert sparse_targets.shape == (batch_size, num_frames)
+ # Targets use the stage.tau format, bounded to [0, max_stages], so all values are non-negative
+ assert (sparse_targets >= 0).all()
+
+ # Check dense_targets exist (for dual mode)
+ assert "dense_targets" in obs
+ dense_targets = obs["dense_targets"]
+ assert dense_targets.shape == (batch_size, num_frames)
+ assert (dense_targets >= 0).all()
+
+ def test_call_with_batched_input(self, mock_clip_model):
+ """Test processor __call__ with a batched input (multiple frames) in dual mode."""
+ from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+
+ config = MockConfig(
+ n_obs_steps=8,
+ max_rewind_steps=4,
+ frame_gap=30,
+ rewind_probability=0.0,
+ language_perturbation_probability=0.0,
+ annotation_mode="dual",
+ sparse_subtask_names=["reach", "grasp"],
+ sparse_temporal_proportions=[0.5, 0.5],
+ dense_subtask_names=["step1", "step2", "step3"],
+ dense_temporal_proportions=[0.33, 0.34, 0.33],
+ )
+
+ # Two episodes with different lengths, each with sparse+dense annotations
+ episodes = [
+ {
+ "dataset_from_index": 0,
+ "dataset_to_index": 200,
+ "task": "task A",
+ "sparse_subtask_names": ["reach", "grasp"],
+ "sparse_subtask_start_frames": [0, 100],
+ "sparse_subtask_end_frames": [100, 200],
+ "dense_subtask_names": ["step1", "step2", "step3"],
+ "dense_subtask_start_frames": [0, 66, 133],
+ "dense_subtask_end_frames": [66, 133, 200],
+ },
+ {
+ "dataset_from_index": 200,
+ "dataset_to_index": 500,
+ "task": "task B",
+ "sparse_subtask_names": ["reach", "grasp"],
+ "sparse_subtask_start_frames": [200, 350],
+ "sparse_subtask_end_frames": [350, 500],
+ "dense_subtask_names": ["step1", "step2", "step3"],
+ "dense_subtask_start_frames": [200, 300, 400],
+ "dense_subtask_end_frames": [300, 400, 500],
+ },
+ ]
+ dataset_meta = MockDatasetMeta(episodes)
+
+ processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
+ processor.train(True)
+
+ batch_size = 2
+ num_frames = config.num_frames
+
+ # Image: (B, T, C, H, W) format
+ dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32)
+ dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32)
+
+ transition = {
+ TransitionKey.OBSERVATION: {
+ config.image_key: dummy_image,
+ config.state_key: dummy_state,
+ },
+ TransitionKey.COMPLEMENTARY_DATA: {
+ "index": np.array([100, 350]), # One frame from each episode
+ "episode_index": np.array([0, 1]),
+ "task": ["task A", "task B"],
+ },
+ }
+
+ result = processor(transition)
+ obs = result[TransitionKey.OBSERVATION]
+
+ # Verify batch dimension is preserved for all outputs
+ assert obs["video_features"].shape[0] == batch_size
+ assert obs["state_features"].shape[0] == batch_size
+ assert obs["lengths"].shape[0] == batch_size
+ assert obs["sparse_targets"].shape[0] == batch_size
+ assert obs["dense_targets"].shape[0] == batch_size # Dual mode has dense targets
+
+ def test_targets_increase_with_progress(self, mock_clip_model):
+ """Test that both sparse and dense targets increase as frame index progresses."""
+ from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+
+ config = MockConfig(
+ n_obs_steps=8,
+ max_rewind_steps=4,
+ frame_gap=30,
+ rewind_probability=0.0,
+ language_perturbation_probability=0.0,
+ annotation_mode="dual",
+ sparse_subtask_names=["phase1", "phase2"],
+ sparse_temporal_proportions=[0.5, 0.5],
+ dense_subtask_names=["a", "b", "c", "d"],
+ dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
+ )
+
+ episodes = [
+ {
+ "dataset_from_index": 0,
+ "dataset_to_index": 300,
+ "task": "test task",
+ "sparse_subtask_names": ["phase1", "phase2"],
+ "sparse_subtask_start_frames": [0, 150],
+ "sparse_subtask_end_frames": [150, 300],
+ "dense_subtask_names": ["a", "b", "c", "d"],
+ "dense_subtask_start_frames": [0, 75, 150, 225],
+ "dense_subtask_end_frames": [75, 150, 225, 300],
+ }
+ ]
+ dataset_meta = MockDatasetMeta(episodes)
+
+ processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
+ processor.train(True)
+
+ num_frames = config.num_frames
+
+ # Test at early, middle, and late points in episode
+ frame_indices = [30, 150, 270]
+ sparse_center_targets = []
+ dense_center_targets = []
+
+ for frame_idx in frame_indices:
+ dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
+ dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
+
+ transition = {
+ TransitionKey.OBSERVATION: {
+ config.image_key: dummy_image,
+ config.state_key: dummy_state,
+ },
+ TransitionKey.COMPLEMENTARY_DATA: {
+ "index": frame_idx,
+ "episode_index": 0,
+ "task": "test task",
+ },
+ }
+
+ result = processor(transition)
+ obs = result[TransitionKey.OBSERVATION]
+ # Get target at center frame (index 4 in 9-frame observation window)
+ sparse_center_targets.append(obs["sparse_targets"][0, 4].item())
+ dense_center_targets.append(obs["dense_targets"][0, 4].item())
+
+ # Both sparse and dense targets should increase with frame index
+ assert sparse_center_targets[0] < sparse_center_targets[2], (
+ f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})"
+ )
+ assert dense_center_targets[0] < dense_center_targets[2], (
+ f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})"
+ )
+
+ def test_progress_labels_exact_values(self, mock_clip_model):
+ """Test that progress labels (stage.tau) are computed correctly for known positions."""
+ from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+
+ # Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
+ config = MockConfig(
+ n_obs_steps=8,
+ max_rewind_steps=4,
+ frame_gap=10, # Smaller gap for easier calculation
+ rewind_probability=0.0,
+ language_perturbation_probability=0.0,
+ annotation_mode="dual",
+ sparse_subtask_names=["A", "B"],
+ sparse_temporal_proportions=[0.5, 0.5],
+ dense_subtask_names=["d1", "d2", "d3", "d4"],
+ dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
+ )
+
+ # Episode: frames 0-99, sparse stages at [0, 50) and [50, 100)
+ # Dense stages at [0, 25), [25, 50), [50, 75), [75, 100)
+ episodes = [
+ {
+ "dataset_from_index": 0,
+ "dataset_to_index": 100,
+ "task": "test",
+ "sparse_subtask_names": ["A", "B"],
+ "sparse_subtask_start_frames": [0, 50],
+ "sparse_subtask_end_frames": [50, 100],
+ "dense_subtask_names": ["d1", "d2", "d3", "d4"],
+ "dense_subtask_start_frames": [0, 25, 50, 75],
+ "dense_subtask_end_frames": [25, 50, 75, 100],
+ }
+ ]
+ dataset_meta = MockDatasetMeta(episodes)
+
+ processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
+ processor.train(True)
+
+ num_frames = config.num_frames
+
+ # Test at frame 50 (center of episode)
+ # With frame_gap=10, n_obs_steps=8:
+ # obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames)
+ dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
+ dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
+
+ transition = {
+ TransitionKey.OBSERVATION: {
+ config.image_key: dummy_image,
+ config.state_key: dummy_state,
+ },
+ TransitionKey.COMPLEMENTARY_DATA: {
+ "index": 50,
+ "episode_index": 0,
+ "task": "test",
+ },
+ }
+
+ result = processor(transition)
+ obs = result[TransitionKey.OBSERVATION]
+ sparse_targets = obs["sparse_targets"][0] # (13,)
+ dense_targets = obs["dense_targets"][0] # (13,)
+
+ # First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind)
+ # Check that obs frames have non-zero targets
+ obs_sparse = sparse_targets[:9]
+ obs_dense = dense_targets[:9]
+
+ # Verify targets are monotonically increasing for observation frames
+ for i in range(1, 9):
+ assert obs_sparse[i] >= obs_sparse[i - 1], (
+ f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}"
+ )
+ assert obs_dense[i] >= obs_dense[i - 1], (
+ f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}"
+ )
+
+ # Rewind slots should be zero when rewind is disabled
+ rewind_targets = sparse_targets[9:]
+ assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled"
+
+ # Check stage transitions: frame 50 is at boundary of sparse stage A->B
+ # Center frame (index 4) corresponds to actual frame 50
+ center_sparse = obs_sparse[4].item()
+ # At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0)
+ assert 0.9 <= center_sparse <= 1.1, (
+ f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}"
+ )
+
+ def test_rewind_augmentation_applied(self, mock_clip_model):
+ """Test that rewind augmentation correctly extends sequence and generates targets."""
+ import random
+
+ from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+
+ config = MockConfig(
+ n_obs_steps=8,
+ max_rewind_steps=4,
+ frame_gap=10,
+ rewind_probability=1.0, # Always apply rewind
+ language_perturbation_probability=0.0,
+ annotation_mode="dual",
+ sparse_subtask_names=["A", "B"],
+ sparse_temporal_proportions=[0.5, 0.5],
+ dense_subtask_names=["d1", "d2"],
+ dense_temporal_proportions=[0.5, 0.5],
+ )
+
+ episodes = [
+ {
+ "dataset_from_index": 0,
+ "dataset_to_index": 200,
+ "task": "test",
+ "sparse_subtask_names": ["A", "B"],
+ "sparse_subtask_start_frames": [0, 100],
+ "sparse_subtask_end_frames": [100, 200],
+ "dense_subtask_names": ["d1", "d2"],
+ "dense_subtask_start_frames": [0, 100],
+ "dense_subtask_end_frames": [100, 200],
+ }
+ ]
+ dataset_meta = MockDatasetMeta(episodes)
+
+ processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
+ processor.train(True)
+
+ num_frames = config.num_frames # 13
+
+ # Test at frame 150 (center of bidirectional window)
+ # With n_obs_steps=8, half_steps=4, frame_gap=10:
+ # - Earliest obs frame = 150 - 4*10 = 110
+ # - Rewind can go back from 110 to frames like 100, 90, 80, 70
+ # - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4)
+ dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
+ dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
+
+ transition = {
+ TransitionKey.OBSERVATION: {
+ config.image_key: dummy_image,
+ config.state_key: dummy_state,
+ },
+ TransitionKey.COMPLEMENTARY_DATA: {
+ "index": 150,
+ "episode_index": 0,
+ "task": "test",
+ },
+ }
+
+ # Seed random for reproducibility
+ random.seed(42)
+ result = processor(transition)
+ obs = result[TransitionKey.OBSERVATION]
+
+ lengths = obs["lengths"][0].item()
+ sparse_targets = obs["sparse_targets"][0]
+
+ # With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind)
+ assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}"
+ assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}"
+
+ # Rewind targets should be non-zero for frames within valid length
+ n_obs_frames = 9
+ rewind_count = lengths - n_obs_frames
+
+ if rewind_count > 0:
+ # Check that rewind frames have targets
+ rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count]
+ # Rewind frames are from BEFORE the earliest obs frame (110)
+ # These frames (100, 90, 80, 70) are earlier in the episode
+ earliest_obs_target = sparse_targets[0].item() # Frame 110
+
+ # Rewind targets should be less than earliest obs (they're from earlier frames)
+ for i, rt in enumerate(rewind_targets):
+ assert rt.item() < earliest_obs_target, (
+ f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})"
+ )
+
+ # Rewind targets should be decreasing (going further back in time)
+ for i in range(1, len(rewind_targets)):
+ assert rewind_targets[i] <= rewind_targets[i - 1], (
+ f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}"
+ )
+
+ def test_full_sequence_target_consistency(self, mock_clip_model):
+ """Test that the full sequence of targets is consistent with frame positions."""
+ from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+ from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
+
+ config = MockConfig(
+ n_obs_steps=8,
+ max_rewind_steps=4,
+ frame_gap=10,
+ rewind_probability=0.0,
+ language_perturbation_probability=0.0,
+ annotation_mode="dual",
+ sparse_subtask_names=["s1", "s2", "s3"],
+ sparse_temporal_proportions=[0.33, 0.34, 0.33],
+ dense_subtask_names=["d1", "d2"],
+ dense_temporal_proportions=[0.5, 0.5],
+ )
+
+ # 3 sparse stages: [0, 33), [33, 66), [66, 100)
+ # 2 dense stages: [0, 50), [50, 100)
+ episodes = [
+ {
+ "dataset_from_index": 0,
+ "dataset_to_index": 100,
+ "task": "test",
+ "sparse_subtask_names": ["s1", "s2", "s3"],
+ "sparse_subtask_start_frames": [0, 33, 66],
+ "sparse_subtask_end_frames": [33, 66, 100],
+ "dense_subtask_names": ["d1", "d2"],
+ "dense_subtask_start_frames": [0, 50],
+ "dense_subtask_end_frames": [50, 100],
+ }
+ ]
+ dataset_meta = MockDatasetMeta(episodes)
+
+ processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
+ processor.train(True)
+
+ num_frames = config.num_frames
+
+ # Test at frame 50 (middle of episode)
+ dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
+ dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
+
+ transition = {
+ TransitionKey.OBSERVATION: {
+ config.image_key: dummy_image,
+ config.state_key: dummy_state,
+ },
+ TransitionKey.COMPLEMENTARY_DATA: {
+ "index": 50,
+ "episode_index": 0,
+ "task": "test",
+ },
+ }
+
+ result = processor(transition)
+ obs = result[TransitionKey.OBSERVATION]
+ sparse_targets = obs["sparse_targets"][0]
+ dense_targets = obs["dense_targets"][0]
+
+ # Manually compute expected targets for observation frames
+ # With frame_gap=10, n_obs_steps=8, center at 50:
+ # obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90]
+ expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90]
+
+ sparse_names = ["s1", "s2", "s3"]
+ sparse_starts = [0, 33, 66]
+ sparse_ends = [33, 66, 100]
+ sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33}
+
+ dense_names = ["d1", "d2"]
+ dense_starts = [0, 50]
+ dense_ends = [50, 100]
+ dense_props = {"d1": 0.5, "d2": 0.5}
+
+ for i, frame in enumerate(expected_obs_frames):
+ expected_sparse = find_stage_and_tau(
+ frame,
+ 100,
+ sparse_names,
+ sparse_starts,
+ sparse_ends,
+ sparse_names,
+ sparse_props,
+ return_combined=True,
+ )
+ expected_dense = find_stage_and_tau(
+ frame,
+ 100,
+ dense_names,
+ dense_starts,
+ dense_ends,
+ dense_names,
+ dense_props,
+ return_combined=True,
+ )
+
+ actual_sparse = sparse_targets[i].item()
+ actual_dense = dense_targets[i].item()
+
+ assert abs(actual_sparse - expected_sparse) < 0.01, (
+ f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}"
+ )
+ assert abs(actual_dense - expected_dense) < 0.01, (
+ f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}"
+ )
diff --git a/tests/policies/test_sarm_subtask_annotations.py b/tests/policies/test_sarm_subtask_annotations.py
new file mode 100644
index 000000000..0dc087288
--- /dev/null
+++ b/tests/policies/test_sarm_subtask_annotations.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+pytest.importorskip("transformers")
+
+from lerobot.data_processing.sarm_annotations.subtask_annotation import (
+ Subtask,
+ SubtaskAnnotation,
+ Timestamp,
+ compute_temporal_proportions,
+)
+
+
+def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
+ """Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
+ return SubtaskAnnotation(
+ subtasks=[
+ Subtask(
+ name=name,
+ timestamps=Timestamp(
+ start=f"{start // 60:02d}:{start % 60:02d}", end=f"{end // 60:02d}:{end % 60:02d}"
+ ),
+ )
+ for name, start, end in subtasks
+ ]
+ )
+
+
+class TestComputeTemporalProportions:
+ """Tests for compute_temporal_proportions (SARM Paper Formula 1).
+
+ Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
+
+ Key insight: This averages the PROPORTION of each subtask within each trajectory,
+ giving equal weight to all trajectories regardless of absolute length.
+ """
+
+ def test_basic_two_trajectories_equal_proportions(self):
+ """Test with two trajectories that have equal proportions."""
+ # Both trajectories: subtask1 = 50%, subtask2 = 50%
+ # Traj 1: T=100s, subtask1=50s, subtask2=50s
+ # Traj 2: T=200s, subtask1=100s, subtask2=100s
+ annotations = {
+ 0: make_annotation([("subtask1", 0, 50), ("subtask2", 50, 100)]),
+ 1: make_annotation([("subtask1", 0, 100), ("subtask2", 100, 200)]),
+ }
+
+ result = compute_temporal_proportions(annotations)
+
+ # Both should be 0.5
+ assert abs(result["subtask1"] - 0.5) < 1e-6
+ assert abs(result["subtask2"] - 0.5) < 1e-6
+
+ def test_paper_example_different_from_avg_durations(self):
+ """Test that compute_temporal_proportions differs from naive average duration approach.
+
+ This is the key test showing the difference between:
+ - Paper formula: average of (L_i,k / T_i)
+ - Naive approach: mean(L_i,k) / sum(mean(L_i,j))
+ """
+ # Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
+ # Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
+ annotations = {
+ 0: make_annotation([("subtask1", 0, 80), ("subtask2", 80, 100)]),
+ 1: make_annotation([("subtask1", 0, 40), ("subtask2", 40, 200)]),
+ }
+
+ result = compute_temporal_proportions(annotations)
+
+ # Paper formula:
+ # ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
+ # ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
+ assert abs(result["subtask1"] - 0.5) < 1e-6
+ assert abs(result["subtask2"] - 0.5) < 1e-6
+
+ def test_single_trajectory(self):
+ """Test with a single trajectory."""
+ # T=100s, reach=30s, grasp=20s, lift=50s
+ annotations = {
+ 0: make_annotation([("reach", 0, 30), ("grasp", 30, 50), ("lift", 50, 100)]),
+ }
+
+ result = compute_temporal_proportions(annotations)
+
+ assert abs(result["reach"] - 0.3) < 1e-6
+ assert abs(result["grasp"] - 0.2) < 1e-6
+ assert abs(result["lift"] - 0.5) < 1e-6
+
+ def test_sum_to_one(self):
+ """Test that proportions always sum to 1."""
+ # Three episodes with varying proportions
+ annotations = {
+ 0: make_annotation([("a", 0, 10), ("b", 10, 50), ("c", 50, 100)]), # 0.1, 0.4, 0.5
+ 1: make_annotation([("a", 0, 20), ("b", 20, 70), ("c", 70, 100)]), # 0.2, 0.5, 0.3
+ 2: make_annotation([("a", 0, 30), ("b", 30, 90), ("c", 90, 100)]), # 0.3, 0.6, 0.1
+ }
+
+ result = compute_temporal_proportions(annotations)
+
+ total = sum(result.values())
+ assert abs(total - 1.0) < 1e-6
+
+ def test_empty_annotations_returns_empty(self):
+ """Test that empty annotations returns empty dict."""
+ result = compute_temporal_proportions({})
+ assert result == {}
+
+ def test_uniform_proportions(self):
+ """Test with uniform proportions across subtasks."""
+ # Each subtask takes 25% of each episode
+ annotations = {
+ 0: make_annotation([("a", 0, 25), ("b", 25, 50), ("c", 50, 75), ("d", 75, 100)]),
+ 1: make_annotation([("a", 0, 50), ("b", 50, 100), ("c", 100, 150), ("d", 150, 200)]),
+ }
+
+ result = compute_temporal_proportions(annotations)
+
+ for name in ["a", "b", "c", "d"]:
+ assert abs(result[name] - 0.25) < 1e-6
diff --git a/tests/policies/test_sarm_utils.py b/tests/policies/test_sarm_utils.py
new file mode 100644
index 000000000..510477ec8
--- /dev/null
+++ b/tests/policies/test_sarm_utils.py
@@ -0,0 +1,615 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+import torch
+
+from lerobot.policies.sarm.sarm_utils import (
+ apply_rewind_augmentation,
+ compute_absolute_indices,
+ compute_tau,
+ find_stage_and_tau,
+ normalize_stage_tau,
+ temporal_proportions_to_breakpoints,
+)
+
+
+class TestProgressLabelsWithModes:
+ """End-to-end tests for progress label generation in different modes."""
+
+ def test_sparse_mode_single_stage(self):
+ """Sparse mode with single stage should give linear progress."""
+ episode_length = 300
+ global_names = ["task"]
+ proportions = {"task": 1.0}
+
+ # Test at various frames
+ for frame in [0, 100, 200, 299]:
+ stage, tau = find_stage_and_tau(
+ frame, episode_length, None, None, None, global_names, proportions
+ )
+
+ expected_tau = frame / (episode_length - 1)
+ assert stage == 0
+ assert abs(tau - expected_tau) < 1e-5
+
+ def test_sparse_mode_multi_stage(self):
+ """Sparse mode with multiple stages."""
+ global_names = ["reach", "grasp", "lift", "place"]
+ proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3}
+
+ subtask_names = ["reach", "grasp", "lift", "place"]
+ subtask_starts = [0, 60, 120, 210]
+ subtask_ends = [59, 119, 209, 299]
+
+ # Check stages are correctly identified
+ stage_at_30, _ = find_stage_and_tau(
+ 30, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage_at_30 == 0
+
+ stage_at_90, _ = find_stage_and_tau(
+ 90, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage_at_90 == 1
+
+ stage_at_150, _ = find_stage_and_tau(
+ 150, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage_at_150 == 2
+
+ def test_dense_mode_more_stages(self):
+ """Dense mode should work with more fine-grained stages."""
+ global_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
+ proportions = dict.fromkeys(global_names, 1 / 8)
+
+ subtask_names = global_names
+ subtask_starts = [i * 50 for i in range(8)]
+ subtask_ends = [(i + 1) * 50 - 1 for i in range(8)]
+
+ # Each stage should occupy 50 frames
+ for stage_idx in range(8):
+ mid_frame = stage_idx * 50 + 25
+ stage, _ = find_stage_and_tau(
+ mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == stage_idx
+
+
+class TestComputeAbsoluteIndices:
+ """Tests for compute_absolute_indices (bidirectional sampling)."""
+
+ def test_no_clamping_when_in_middle(self):
+ """When frame is in middle of episode, no clamping should occur."""
+ frame_idx = 300
+ ep_start = 0
+ ep_end = 1000
+ n_obs_steps = 8
+ frame_gap = 30
+
+ indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
+
+ # All should be valid (no out of bounds)
+ assert out_of_bounds.sum() == 0
+
+ # Check bidirectional offsets from the center frame: [-120, -90, -60, -30, 0, 30, 60, 90, 120]
+ half_steps = n_obs_steps // 2
+ expected = (
+ [frame_idx - frame_gap * i for i in range(half_steps, 0, -1)]
+ + [frame_idx]
+ + [frame_idx + frame_gap * i for i in range(1, half_steps + 1)]
+ )
+ assert indices.tolist() == expected
+
+ # Center frame (index 4) should be frame_idx
+ assert indices[half_steps] == frame_idx
+
+ def test_clamping_at_episode_start(self):
+ """Early frames should be clamped to episode start."""
+ frame_idx = 50 # Not enough history for full past window
+ ep_start = 0
+ ep_end = 1000
+ n_obs_steps = 8
+ frame_gap = 30
+
+ indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
+
+ # Some past frames should be clamped (out_of_bounds = 1)
+ assert out_of_bounds.sum() > 0
+
+ # All indices should be >= ep_start
+ assert (indices >= ep_start).all()
+
+ # Center index should be frame_idx
+ half_steps = n_obs_steps // 2
+ assert indices[half_steps] == frame_idx
+
+ def test_clamping_at_episode_end(self):
+ """Late frames should be clamped to episode end."""
+ frame_idx = 950 # Not enough future for full window
+ ep_start = 0
+ ep_end = 1000
+ n_obs_steps = 8
+ frame_gap = 30
+
+ indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
+
+ # Some future frames should be clamped
+ assert out_of_bounds.sum() > 0
+
+ # All indices should be < ep_end
+ assert (indices < ep_end).all()
+
+ # Center index should be frame_idx
+ half_steps = n_obs_steps // 2
+ assert indices[half_steps] == frame_idx
+
+ def test_sequence_is_monotonic(self):
+ """Frame indices should be monotonically increasing."""
+ for frame_idx in [50, 100, 300, 950]:
+ indices, _ = compute_absolute_indices(frame_idx, 0, 1000, 8, 30)
+
+ # Check monotonic (non-decreasing due to clamping)
+ diffs = indices[1:] - indices[:-1]
+ assert (diffs >= 0).all(), f"Non-monotonic at frame {frame_idx}"
+
+
+class TestComputeTau:
+ """Tests for compute_tau (within-subtask progress).
+
+ Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
+ """
+
+ def test_at_start(self):
+ """τ should be 0 at subtask start."""
+ tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
+ assert tau == 0.0
+
+ def test_at_end(self):
+ """τ should be 1 at subtask end."""
+ tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
+ assert tau == 1.0
+
+ def test_at_middle(self):
+ """τ should be 0.5 at subtask midpoint."""
+ tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
+ assert abs(tau - 0.5) < 1e-6
+
+ def test_quarter_progress(self):
+ """Test τ at 25% through subtask."""
+ tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
+ assert abs(tau - 0.25) < 1e-6
+
+ def test_zero_duration_subtask(self):
+ """τ should be 1.0 for zero-duration subtask."""
+ tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
+ assert tau == 1.0
+
+ def test_clamps_below_zero(self):
+ """τ should be clamped to 0 if frame is before subtask."""
+ tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
+ assert tau == 0.0
+
+ def test_clamps_above_one(self):
+ """τ should be clamped to 1 if frame is after subtask."""
+ tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
+ assert tau == 1.0
+
+ def test_float_inputs(self):
+ """Test with float frame indices (from interpolation)."""
+ tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
+ expected = (25.5 - 10.0) / (50.0 - 10.0)
+ assert abs(tau - expected) < 1e-6
+
+
+class TestFindStageAndTau:
+ """Tests for find_stage_and_tau logic.
+
+ This function is the core of progress label computation. It determines
+ which stage a frame belongs to and the within-stage progress (tau).
+ """
+
+ def test_single_stage_mode_linear_progress(self):
+ """Single-stage mode should give linear progress from 0 to 1."""
+ episode_length = 100
+
+ # Frame 0 -> tau = 0
+ stage, tau = find_stage_and_tau(0, episode_length, None, None, None, ["task"], {"task": 1.0})
+ assert stage == 0
+ assert abs(tau - 0.0) < 1e-6
+
+        # Frame 50 -> tau = 50/99 ≈ 0.505
+ stage, tau = find_stage_and_tau(50, episode_length, None, None, None, ["task"], {"task": 1.0})
+ assert stage == 0
+ assert abs(tau - 50 / 99) < 1e-6
+
+ # Frame 99 -> tau = 1.0
+ stage, tau = find_stage_and_tau(99, episode_length, None, None, None, ["task"], {"task": 1.0})
+ assert stage == 0
+ assert abs(tau - 1.0) < 1e-6
+
+ def test_multi_stage_within_subtask(self):
+ """Test finding stage when frame is within a subtask."""
+ global_names = ["reach", "grasp", "lift"]
+ proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5}
+
+ subtask_names = ["reach", "grasp", "lift"]
+ subtask_starts = [0, 30, 50]
+ subtask_ends = [29, 49, 99]
+
+ # Frame 15 in "reach" stage (index 0)
+ stage, tau = find_stage_and_tau(
+ 15, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == 0
+ assert abs(tau - 15 / 29) < 1e-6
+
+ # Frame 40 in "grasp" stage (index 1)
+ stage, tau = find_stage_and_tau(
+ 40, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == 1
+ # tau = (40 - 30) / (49 - 30) = 10/19
+ assert abs(tau - 10 / 19) < 1e-6
+
+ # Frame 75 in "lift" stage (index 2)
+ stage, tau = find_stage_and_tau(
+ 75, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == 2
+ # tau = (75 - 50) / (99 - 50) = 25/49
+ assert abs(tau - 25 / 49) < 1e-6
+
+ def test_frame_at_subtask_boundaries(self):
+ """Test frames exactly at subtask boundaries."""
+ global_names = ["a", "b"]
+ proportions = {"a": 0.5, "b": 0.5}
+
+ subtask_names = ["a", "b"]
+ subtask_starts = [0, 50]
+ subtask_ends = [49, 99]
+
+ # Frame at start of first subtask
+ stage, tau = find_stage_and_tau(
+ 0, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == 0
+ assert tau == 0.0
+
+ # Frame at end of first subtask
+ stage, tau = find_stage_and_tau(
+ 49, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == 0
+ assert tau == 1.0
+
+ # Frame at start of second subtask
+ stage, tau = find_stage_and_tau(
+ 50, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == 1
+ assert tau == 0.0
+
+ def test_frame_after_last_subtask(self):
+ """Frames after last subtask should return last stage with high tau."""
+ global_names = ["a", "b"]
+ proportions = {"a": 0.5, "b": 0.5}
+
+ subtask_names = ["a", "b"]
+ subtask_starts = [0, 30]
+ subtask_ends = [29, 59]
+
+ # Frame 80 is after last subtask
+ stage, tau = find_stage_and_tau(
+ 80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
+ )
+ assert stage == 1 # Last stage
+ assert tau == 0.999 # Nearly complete
+
+
+class TestEndToEndProgressLabeling:
+ """End-to-end tests for progress label computation using normalize_stage_tau."""
+
+ def test_consistent_semantic_meaning(self):
+ """Test that same subtask completion maps to same progress across trajectories.
+
+ This is the key semantic property: "end of subtask 1" should always
+ mean the same progress value regardless of trajectory speed.
+ """
+ proportions = [0.3, 0.5, 0.2]
+
+ # Fast trajectory: subtask 1 ends at frame 30 (of 100)
+ tau_fast = compute_tau(30, 0, 30) # = 1.0
+ y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions)
+
+ # Slow trajectory: subtask 1 ends at frame 90 (of 300)
+ tau_slow = compute_tau(90, 0, 90) # = 1.0
+ y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions)
+
+ # Both should map to same progress (0.3 = end of subtask 1)
+ assert abs(y_fast - y_slow) < 1e-6
+ assert abs(y_fast - 0.3) < 1e-6
+
+ def test_monotonic_within_subtask(self):
+ """Test that progress is monotonically increasing within a subtask."""
+ proportions = [0.4, 0.6]
+
+ prev_y = -1
+ for tau in np.linspace(0, 1, 11):
+ y = normalize_stage_tau(0 + tau, temporal_proportions=proportions)
+ assert y > prev_y or (tau == 0 and y == 0)
+ prev_y = y
+
+ def test_continuous_across_subtasks(self):
+ """Test that progress is continuous at subtask boundaries."""
+ proportions = [0.3, 0.5, 0.2]
+
+ # End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0
+ y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions)
+
+ # Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0
+ y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions)
+
+ # Should be equal (P_1 = 0.3)
+ assert abs(y_end_0 - y_start_1) < 1e-6
+
+ # End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0
+ y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions)
+
+ # Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0
+ y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions)
+
+ # Should be equal (P_2 = 0.8)
+ assert abs(y_end_1 - y_start_2) < 1e-6
+
+
+class TestTemporalProportionsToBreakpoints:
+ """Tests for temporal_proportions_to_breakpoints.
+
+ Converts temporal proportions to cumulative breakpoints for normalization.
+ Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]
+ """
+
+ def test_basic_conversion(self):
+ """Test basic conversion from proportions to breakpoints."""
+ proportions = [0.3, 0.5, 0.2]
+ breakpoints = temporal_proportions_to_breakpoints(proportions)
+
+ assert breakpoints is not None
+ assert len(breakpoints) == 4
+ assert breakpoints[0] == 0.0
+ assert abs(breakpoints[1] - 0.3) < 1e-6
+ assert abs(breakpoints[2] - 0.8) < 1e-6
+ assert breakpoints[3] == 1.0
+
+ def test_dict_input(self):
+ """Test with dict input."""
+ proportions = {"a": 0.25, "b": 0.25, "c": 0.5}
+ breakpoints = temporal_proportions_to_breakpoints(proportions)
+
+ assert breakpoints is not None
+ assert len(breakpoints) == 4
+ assert breakpoints[0] == 0.0
+ assert breakpoints[-1] == 1.0
+
+ def test_dict_with_subtask_names_order(self):
+ """Test that subtask_names determines order for dict input."""
+ proportions = {"c": 0.5, "a": 0.2, "b": 0.3} # Dict order
+ subtask_names = ["a", "b", "c"] # Different order
+
+ breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names)
+
+ # Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5
+ assert abs(breakpoints[1] - 0.2) < 1e-6 # a
+ assert abs(breakpoints[2] - 0.5) < 1e-6 # a + b = 0.5
+ assert breakpoints[3] == 1.0 # a + b + c = 1.0
+
+ def test_uniform_proportions(self):
+ """Test with uniform proportions."""
+ proportions = [0.25, 0.25, 0.25, 0.25]
+ breakpoints = temporal_proportions_to_breakpoints(proportions)
+
+ expected = [0.0, 0.25, 0.5, 0.75, 1.0]
+ for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)):
+ assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch"
+
+ def test_none_input(self):
+ """Test that None input returns None."""
+ result = temporal_proportions_to_breakpoints(None)
+ assert result is None
+
+ def test_normalization(self):
+ """Test that non-normalized proportions are normalized."""
+ # Proportions sum to 2.0, not 1.0
+ proportions = [0.6, 1.0, 0.4]
+ breakpoints = temporal_proportions_to_breakpoints(proportions)
+
+ # Should be normalized: [0.3, 0.5, 0.2] -> [0, 0.3, 0.8, 1.0]
+ assert breakpoints[-1] == 1.0
+ assert abs(breakpoints[1] - 0.3) < 1e-6
+
+
+class TestNormalizeStageTau:
+ """Tests for normalize_stage_tau.
+
+ Normalizes stage+tau values to [0, 1] using breakpoints.
+ """
+
+ def test_linear_fallback(self):
+ """Test linear normalization when only num_stages is provided."""
+ # 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0]
+
+ # Stage 0 start
+ assert normalize_stage_tau(0.0, num_stages=4) == 0.0
+
+ # Stage 0 end / Stage 1 start
+ assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6
+
+ # Stage 1 middle
+ assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6
+
+ # Stage 3 end
+ assert normalize_stage_tau(4.0, num_stages=4) == 1.0
+
+ def test_with_custom_breakpoints(self):
+ """Test with custom breakpoints."""
+ # Non-linear breakpoints
+ breakpoints = [0.0, 0.1, 0.5, 1.0] # 3 stages
+
+ # Stage 0: maps [0, 1) to [0.0, 0.1)
+ assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6
+
+ # Stage 1: maps [1, 2) to [0.1, 0.5)
+ assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6
+
+ # Stage 2: maps [2, 3) to [0.5, 1.0)
+ assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6
+
+ def test_with_temporal_proportions(self):
+ """Test with temporal proportions (auto-computed breakpoints)."""
+ proportions = {"a": 0.2, "b": 0.3, "c": 0.5}
+ subtask_names = ["a", "b", "c"]
+
+ # Stage 0 end should map to 0.2
+ result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names)
+ assert abs(result - 0.2) < 1e-6
+
+ # Stage 1 end should map to 0.5
+ result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names)
+ assert abs(result - 0.5) < 1e-6
+
+ def test_tensor_input(self):
+ """Test with tensor input."""
+ x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
+ breakpoints = [0.0, 0.3, 0.8, 1.0] # 3 stages
+
+ result = normalize_stage_tau(x, breakpoints=breakpoints)
+
+ assert isinstance(result, torch.Tensor)
+ assert result.shape == x.shape
+ assert abs(result[0].item() - 0.0) < 1e-6
+ assert abs(result[2].item() - 0.3) < 1e-6 # End of stage 0
+ assert abs(result[4].item() - 0.8) < 1e-6 # End of stage 1
+
+ def test_clamping(self):
+ """Test that output is clamped to [0, 1]."""
+ # Below 0
+ assert normalize_stage_tau(-0.5, num_stages=4) == 0.0
+
+ # Above num_stages
+ assert normalize_stage_tau(5.0, num_stages=4) == 1.0
+
+ def test_batch_tensor(self):
+ """Test with batched tensor."""
+ x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]]) # (2, 3)
+
+ result = normalize_stage_tau(x, num_stages=3)
+
+ assert result.shape == (2, 3)
+ assert (result >= 0).all()
+ assert (result <= 1).all()
+
+ def test_requires_one_of_inputs(self):
+ """Test that at least one input method is required."""
+ with pytest.raises(ValueError):
+ normalize_stage_tau(1.0)
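+
+    def test_uniform_proportions_match_linear_fallback(self):
+        """Uniform proportions should reproduce the linear num_stages fallback.
+
+        Illustrative consistency check (sketch): with k equal proportions,
+        the cumulative breakpoints are evenly spaced, so breakpoint-based
+        normalization should agree with the linear num_stages path.
+        """
+        proportions = [0.25, 0.25, 0.25, 0.25]
+        for x in [0.0, 0.5, 1.0, 2.5, 4.0]:
+            y_prop = normalize_stage_tau(x, temporal_proportions=proportions)
+            y_linear = normalize_stage_tau(x, num_stages=4)
+            assert abs(y_prop - y_linear) < 1e-6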
+
+
+class TestRewindAugmentation:
+ """Tests for rewind augmentation logic with bidirectional observation sampling.
+
+ Rewind appends frames before the earliest observation frame, going backwards.
+ With bidirectional sampling centered at frame_idx:
+ - Earliest obs frame = frame_idx - half_steps * frame_gap
+ - Rewind goes backwards from that point
+ """
+
+ def test_rewind_indices_go_backwards_from_earliest_obs(self):
+ """Rewind indices should go backwards from earliest observation frame."""
+ frame_idx = 300 # Center of bidirectional window
+ ep_start = 0
+ n_obs_steps = 4 # half_steps = 2
+ frame_gap = 30
+
+ # Earliest obs frame = 300 - 2*30 = 240
+ # Rewind goes backwards: 210, 180
+ rewind_step, rewind_indices = apply_rewind_augmentation(
+ frame_idx,
+ ep_start,
+ n_obs_steps=n_obs_steps,
+ max_rewind_steps=2,
+ frame_gap=frame_gap,
+ rewind_step=2,
+ )
+
+ assert rewind_step == 2
+ assert len(rewind_indices) == 2
+ # First rewind frame is closest to obs window, second is further back
+ assert rewind_indices[0] == 210 # 240 - 30
+ assert rewind_indices[1] == 180 # 240 - 60
+ assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending"
+
+ def test_rewind_goes_backward_through_history(self):
+ """Rewind frames should go backward before the observation window."""
+ frame_idx = 450 # Center of bidirectional window
+ ep_start = 0
+ n_obs_steps = 8 # half_steps = 4
+ frame_gap = 30
+
+ # Earliest obs frame = 450 - 4*30 = 330
+ # Rewind from 330: [300, 270, 240]
+ rewind_step, rewind_indices = apply_rewind_augmentation(
+ frame_idx,
+ ep_start,
+ n_obs_steps=n_obs_steps,
+ max_rewind_steps=4,
+ frame_gap=frame_gap,
+ rewind_step=3,
+ )
+
+ assert rewind_step == 3
+ expected = [300, 270, 240] # Going backwards from 330
+ assert rewind_indices == expected
+
+ def test_no_rewind_when_obs_window_at_episode_start(self):
+ """No rewind when observation window reaches episode start."""
+ frame_idx = 120 # Center of window
+ ep_start = 0
+ n_obs_steps = 8 # half_steps = 4
+ frame_gap = 30
+
+ # Earliest obs frame = 120 - 4*30 = 0 (at episode start)
+ rewind_step, rewind_indices = apply_rewind_augmentation(
+ frame_idx, ep_start, n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap
+ )
+
+ # No room for rewind
+ assert rewind_step == 0
+ assert rewind_indices == []
+
+ def test_rewind_targets_are_decreasing(self):
+ """Progress targets for rewind frames should be decreasing."""
+ # Simulate progress values
+ obs_progress = [0.1, 0.2, 0.3, 0.4, 0.5] # Forward progress
+
+ # Rewind reverses progress
+ rewind_indices = [4, 3, 2] # Go backwards through indices
+ rewind_progress = [obs_progress[i] for i in rewind_indices]
+
+ # Should be decreasing
+ for i in range(len(rewind_progress) - 1):
+ assert rewind_progress[i] > rewind_progress[i + 1]
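+
+    def test_rewind_indices_stay_within_episode(self):
+        """Rewind indices should never precede ep_start.
+
+        Illustrative boundary check (sketch): assumes rewind is truncated at
+        the episode start, consistent with
+        test_no_rewind_when_obs_window_at_episode_start above.
+        """
+        frame_idx = 150
+        ep_start = 0
+        # Earliest obs frame = 150 - 2*30 = 90; at most 3 rewind frames fit
+        rewind_step, rewind_indices = apply_rewind_augmentation(
+            frame_idx, ep_start, n_obs_steps=4, max_rewind_steps=10, frame_gap=30
+        )
+        assert all(idx >= ep_start for idx in rewind_indices)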