diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index aae7372fa..85a79ef17 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -42,6 +42,10 @@ - local: xvla title: X-VLA title: "Policies" +- sections: + - local: sarm + title: SARM + title: "Reward Models" - sections: - local: async title: Use Async Inference diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx new file mode 100644 index 000000000..321097692 --- /dev/null +++ b/docs/source/sarm.mdx @@ -0,0 +1,586 @@ +# SARM: Stage-Aware Reward Modeling + +SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC). + +**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358) + +## Why Reward Models? + +Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways; two promising applications are (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement. + +## Overview + +SARM has the following features: + +1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage +2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels +3. 
**Temporal proportions**: Computes dataset-level priors (α̅\_k) for each subtask to normalize progress across variable-length demonstrations + +SARM trains on a compact **stage+tau** target for each frame: + +- **stage**: integer stage index `k ∈ {0, ..., K-1}` +- **τ (tau)**: within-stage progress `τ ∈ [0, 1]` +- **target encoding**: `y = k + τ` (this is what the dataset processor produces) + +At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`). + +This matches **Formula (2)** from the paper: + +``` +progress_t = P_{k-1} + α̅_k × τ_t +``` + +Where: + +- `τ_t = (t - s_k) / (e_k - s_k)` is within-subtask normalized time +- `P_{k-1}` is cumulative prior (sum of previous subtask proportions) +- `α̅_k` is the temporal proportion for subtask k + +This ensures identical task states map to consistent progress values, even across demonstrations of different lengths. 
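To make the conversion concrete, here is a minimal, self-contained sketch of Formula (2). The `normalize_progress` helper is hypothetical (in LeRobot the conversion happens inside the SARM policy using the stored proportions), but the arithmetic matches the formula above:

```python
def normalize_progress(y: float, proportions: list[float]) -> float:
    """Convert a raw stage+tau target y = k + tau into normalized progress
    in [0, 1] using per-subtask temporal proportions (alpha_bar_k)."""
    k = min(int(y), len(proportions) - 1)  # stage index k
    tau = y - k                            # within-stage progress tau
    cumulative = sum(proportions[:k])      # P_{k-1}: prior subtasks' share
    return cumulative + proportions[k] * tau

# Three subtasks taking on average 20%, 50%, and 30% of an episode:
alpha = [0.2, 0.5, 0.3]
print(normalize_progress(1.5, alpha))  # halfway through stage 1 -> 0.45
```

Note that halfway through stage 1 maps to 0.45 rather than 0.5: a short subtask contributes less to overall progress than a long one, which is exactly what the temporal proportions encode.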
+ +## Inputs and Targets (What the new code expects) + +SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which: + +- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features` +- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`) +- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + τ` +- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation) + +At minimum, each training sample needs: + +- `task` (string): task description +- `policy.image_key` images and `policy.state_key` states from the dataset + +--- + +## Annotation Modes + +You can choose from **3 annotation modes** that determine how progress labels are computed: + +| Mode | Annotations Required | Heads | Use Case | +| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ | +| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed | +| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages | +| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities | + +### Mode Details + + + + +**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration. + +- **Sparse head**: 1 stage ("task"), linear progress +- **Dense head**: Not used +- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available + +## Set Up Your Environment + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install SARM dependencies by running: + +```bash +pip install -e ".[sarm]" +``` + +Workflow: + +``` +1. Train SARM → 2. Visualize predictions → 3. 
(Optional) Train policy with RA-BC +``` + + + + +**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression. + +- **Sparse head**: 1 stage ("task"), linear progress (auto-generated) +- **Dense head**: Multiple fine-grained stages from VLM annotations +- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages + +Workflow: + +``` +1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC +``` + + + + +**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions. + +- **Sparse head**: High-level stages from VLM annotations +- **Dense head**: Fine-grained stages from VLM annotations +- **Best for**: Complex multi-stage tasks where both granularities are useful + +Workflow: + +``` +1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC +``` + + + + +--- + +## Step 1: Subtask Annotation + + + + +**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically. + + + + +Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated. 
+ +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --dense-only \ + --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \ + --video-key observation.images.base \ + --num-workers 4 \ + --push-to-hub +``` + +**What gets saved:** + +- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`) +- `meta/temporal_proportions_dense.json` - Dense temporal proportions +- Per-episode columns in `episodes/*.parquet`: + - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames` + - (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`) + + + + +Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM. + +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \ + --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \ + --video-key observation.images.base \ + --num-workers 4 \ + --push-to-hub +``` + +**What gets saved:** + +- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions +- `meta/temporal_proportions_dense.json` - Dense temporal proportions +- Per-episode columns in `episodes/*.parquet`: + - `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames` + - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames` + - (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`) + + + + +### Annotation Arguments + +| Argument | Description | +| ---------------------- | 
------------------------------------------------------------------------------- | | `--repo-id` | HuggingFace dataset repository ID | | `--sparse-subtasks` | Comma-separated list of high-level subtask names | | `--dense-subtasks` | Comma-separated list of fine-grained subtask names | | `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) | | `--video-key` | Camera/video key to use (e.g., `observation.images.top`) | | `--num-workers` | Number of parallel GPU workers (default: 1) | | `--episodes` | Specific episode indices to annotate (default: all) | | `--skip-existing` | Skip episodes that already have annotations | | `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) | | `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) | + +> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step. + +--- + +## Step 2: Verify Annotations + + + + +**No verification needed!** Skip this step. + + + + +Visualize annotations using the `--visualize-only` flag: + +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --visualize-only \ + --visualize-type dense \ + --num-visualizations 5 \ + --video-key observation.images.base \ + --output-dir ./subtask_viz +``` + + + + +Visualize annotations using the `--visualize-only` flag: + +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --visualize-only \ + --visualize-type both \ + --num-visualizations 5 \ + --video-key observation.images.base \ + --output-dir ./subtask_viz +``` + + + + +This generates visualizations showing video frames with subtask boundaries overlaid and a timeline of subtasks. 
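Beyond the rendered visualizations, you can also sanity-check the saved temporal proportions programmatically. A small sketch (the `load_proportions` helper is hypothetical; the file names are the ones written by the annotation step above):

```python
import json
from pathlib import Path

def load_proportions(meta_dir: Path, granularity: str = "dense") -> dict[str, float]:
    """Read temporal proportions written by the annotation script and check
    that they form a valid distribution over subtasks (they should sum to 1)."""
    path = meta_dir / f"temporal_proportions_{granularity}.json"
    props = json.loads(path.read_text())
    assert abs(sum(props.values()) - 1.0) < 1e-6, "proportions should sum to 1"
    return props

# Example (adjust to your local dataset root):
# props = load_proportions(Path("path/to/your-dataset/meta"))
```

If a single subtask dominates the distribution unexpectedly, that is often a sign the VLM placed a boundary incorrectly on several episodes.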
+ +### Visualization Arguments + +| Argument | Description | +| ---------------------- | -------------------------------------------------------------- | +| `--visualize-only` | Only visualize existing annotations (no generation) | +| `--num-visualizations` | Number of episodes to visualize (default: 5) | +| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` | + +**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run. + +--- + +## Step 3: Train SARM + + + + +Train with **no annotations** - uses linear progress from 0 to 1: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=sarm \ + --policy.annotation_mode=single_stage \ + --policy.image_key=observation.images.base \ + --output_dir=outputs/train/sarm_single \ + --batch_size=32 \ + --steps=5000 \ + --wandb.enable=true \ + --wandb.project=sarm \ + --policy.repo_id=your-username/your-model-name +``` + + + + +Train with **dense annotations only** (sparse auto-generated): + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=sarm \ + --policy.annotation_mode=dense_only \ + --policy.image_key=observation.images.base \ + --output_dir=outputs/train/sarm_dense \ + --batch_size=32 \ + --steps=5000 \ + --wandb.enable=true \ + --wandb.project=sarm \ + --policy.repo_id=your-username/your-model-name +``` + + + + +Train with **both sparse and dense annotations**: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=sarm \ + --policy.annotation_mode=dual \ + --policy.image_key=observation.images.base \ + --output_dir=outputs/train/sarm_dual \ + --batch_size=32 \ + --steps=5000 \ + --wandb.enable=true \ + --wandb.project=sarm \ + --policy.repo_id=your-username/your-model-name +``` + + + + +### Multi-GPU Training + +Add `accelerate launch 
--multi_gpu --num_processes=4` to use multiple GPUs for training. + +### Training Arguments + +| Argument | Description | Default | +| -------------------------- | ----------------------------------------------------------------- | ------------------------ | +| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` | +| `--policy.image_key` | Camera key for images | `observation.images.top` | +| `--policy.state_key` | Key for joint states | `observation.state` | +| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` | +| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 ≈ 1s) | `30` | + +--- + +## Step 4: Visualize Predictions + +Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file. + + + + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --visualize-only \ + --num-visualizations 5 \ + --head-mode sparse \ + --output-dir ./sarm_viz +``` + + + + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --visualize-only \ + --num-visualizations 5 \ + --head-mode dense \ + --output-dir ./sarm_viz +``` + + + + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --visualize-only \ + --num-visualizations 5 \ + --head-mode both \ + --output-dir ./sarm_viz +``` + + + + +The visualization shows: + +- **Progress plot**: Predicted progress (and optional annotation-derived “GT” when available and `--stride 1`) +- **Stage probabilities**: Stacked area plot of predicted stage probabilities +- **Sample frames**: Key frames 
from the episode with progress/stage labels + +### Visualization Arguments + +| Argument | Description | +| ---------------------- | --------------------------------------------------------- | +| `--visualize-only` | Only visualize predictions (no RABC computation) | +| `--num-visualizations` | Number of episodes to visualize (default: 5) | +| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | +| `--stride` | Compute every N frames, interpolate the rest (default: 1) | + +--- + +## Step 5 (Optional): Train Policy with RA-BC + +Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps: + +1. **Precompute progress values** for all frames using the trained SARM model +2. **Train policy** with RA-BC weighting using the precomputed values + +### How RA-BC Works + +For each training sample, RA-BC computes the progress delta: + +``` +r_i = φ(o_{t+Δ}) - φ(o_t) +``` + +Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted. 
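As a toy illustration of the delta computation (the progress values and end-of-episode padding here are made up for illustration; the actual training script derives deltas from the precomputed parquet file using the policy's `chunk_size`):

```python
import numpy as np

# Toy per-frame progress predictions for one episode (made-up values).
progress = np.array([0.0, 0.1, 0.25, 0.4, 0.5, 0.7, 0.9, 1.0])
delta = 3  # stands in for the policy's chunk_size

# Pad with the final value so frames near the end compare against episode completion.
padded = np.concatenate([progress, np.full(delta, progress[-1])])
r = padded[delta:] - padded[:-delta]  # r_i = phi(o_{t+delta}) - phi(o_t)
print(r)  # deltas near the end of the episode shrink toward 0
```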
+ +The weighting follows **Equations 8-9** from the paper: + +- **Soft weight**: `w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)` +- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i` + +### Step 5a: Compute SARM Progress Values + +First, run the SARM model on all frames in your dataset to compute progress values: + +```bash +python src/lerobot/policies/sarm/compute_rabc_weights.py \ + --dataset-repo-id your-username/your-dataset \ + --reward-model-path your-username/sarm-model \ + --head-mode sparse \ + --num-visualizations 5 \ + --push-to-hub +``` + +This script: + +- Processes all frames and computes progress values +- Saves progress values to a parquet file next to the dataset on disk (defaults to `/sarm_progress.parquet`) +- Generates visualizations of the first N episodes (default: 5) + +**Arguments:** + +| Argument | Description | Default | +| ---------------------- | -------------------------------------------------------------- | ---------- | +| `--reward-model-path` | Path to trained SARM model | (required) | +| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` | +| `--device` | Device for inference | `cuda` | +| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` | +| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip) | `5` | + +**Output format** (`sarm_progress.parquet`): + +| Column | Description | +| ----------------- | ---------------------------------------------- | +| `index` | Global frame index in dataset | +| `episode_index` | Episode number | +| `frame_index` | Local frame index within episode | +| `progress_sparse` | Sparse head progress value [0, 1] | +| `progress_dense` | Dense head progress value [0, 1] (if computed) | + +### Step 5b: Train Policy with RA-BC + +Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). 
Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=pi0 \ + --use_rabc=true \ + --rabc_head_mode=sparse \ + --rabc_kappa=0.01 \ + --output_dir=outputs/train/policy_rabc \ + --batch_size=32 \ + --steps=40000 +``` + +The training script automatically: + +- Loads the precomputed progress values from the parquet file +- Uses the policy's `chunk_size` to compute progress deltas (Δ) +- Computes sample weights based on progress improvement +- Applies weighted loss during training + +**RA-BC Arguments:** + +| Argument | Description | Default | +| ---------------------- | ---------------------------------------------------------- | ---------------------------------- | +| `--use_rabc` | Enable RA-BC sample weighting | `false` | +| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset | +| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | +| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` | + +### Tuning RA-BC Kappa + +The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively. + +**How the weighting works:** + +| Condition | Weight | +| ------------------- | ----------------------- | +| `delta > kappa` | 1.0 (hard threshold) | +| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 
8 | +| `delta < 0` | 0.0 (negative progress) | + +**Diagnosing kappa issues:** + +Monitor these WandB metrics during training: + +| Metric | Healthy Range | Problem Indicator | +| ------------------ | ------------- | ------------------------- | +| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | +| `rabc_delta_mean` | > 0 | Should be positive | +| `rabc_delta_std` | > 0 | Variance in data quality | + +**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. + +**Setting kappa based on your data:** + +The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`: + +``` +# If delta_mean ≈ 0.03 and delta_std ≈ 0.02: +# Most deltas fall in range [0.01, 0.05] + +# Option 1: Set kappa = delta_mean (medium selectivity) +--rabc_kappa=0.03 + +# Option 2: Set kappa = delta_mean + delta_std (high selectivity) +--rabc_kappa=0.05 + +# Option 3: Set kappa = delta_mean + 2*delta_std (very selective) +--rabc_kappa=0.07 +``` + +**When RA-BC may not help:** + +If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter. 
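Putting the pieces together, here is a NumPy sketch of the weighting rule from Equations 8-9 (an illustration of the math above, not the exact code used by the training script):

```python
import numpy as np

def rabc_weights(r: np.ndarray, kappa: float = 0.01, eps: float = 1e-6) -> np.ndarray:
    """Map progress deltas r_i to sample weights following Equations 8-9."""
    mu, sigma = r.mean(), r.std()
    # Soft weight (Eq. 8): normalize deltas by batch statistics, clip to [0, 1].
    w_soft = np.clip((r - (mu - 2 * sigma)) / (4 * sigma + eps), 0.0, 1.0)
    # Final weight (Eq. 9): full weight above kappa, soft weight in [0, kappa],
    # zero weight for negative progress.
    return np.where(r > kappa, 1.0, np.where(r >= 0.0, w_soft, 0.0))

r = np.array([-0.05, 0.0, 0.005, 0.02, 0.2])
print(rabc_weights(r))  # negative delta -> 0, deltas above kappa -> 1
```

Running this mentally against the table above: only the samples with `0 ≤ r_i ≤ kappa` pass through the soft weight, which is why a kappa set too low makes almost every weight 1.0.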
+ +### Multi-GPU Training with RA-BC + +```bash +accelerate launch \ + --multi_gpu \ + --num_processes=4 \ + src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your-username/your-dataset \ + --policy.type=pi0 \ + --use_rabc=true \ + --rabc_kappa=0.01 \ + --output_dir=outputs/train/policy_rabc \ + --batch_size=32 \ + --steps=40000 +``` + +--- + +## Tips & Best Practices + +### Choosing a Mode + +- **Start with `single_stage`** for quick experiments - no annotation overhead +- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages +- Use **`dual`** for complex tasks where both coarse and fine-grained progress is meaningful + +### Annotation Quality + +1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center" +2. **Verify with visualization**: Always check a few episodes before training +3. **Consistent naming**: Use the same subtask names across all episodes + +### RA-BC + +1. **Train SARM first**: RA-BC quality depends entirely on SARM quality +2. 
**Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa)) + +--- + +## Citation + +```bibtex +@article{chen2025sarm, + title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation}, + author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp}, + journal={arXiv preprint arXiv:2509.25358}, + year={2025} +} +``` diff --git a/pyproject.toml b/pyproject.toml index 9458c0127..ed3e2ae43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] -transformers-dep = ["transformers>=4.53.0,<5.0.0"] +transformers-dep = ["transformers>=4.57.1,<5.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb) # Motors @@ -133,6 +133,7 @@ groot = [ "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] +sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"] xvla = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] @@ -173,6 +174,7 @@ all = [ "lerobot[phone]", "lerobot[libero]", "lerobot[metaworld]", + "lerobot[sarm]" ] [project.scripts] diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 13a8d6525..cee9dfdf9 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -65,9 +65,17 @@ class TrainPipelineConfig(HubMixin): scheduler: LRSchedulerConfig | None = None eval: EvalConfig = field(default_factory=EvalConfig) wandb: WandBConfig = field(default_factory=WandBConfig) - checkpoint_path: Path | None = field(init=False, default=None) + + # RA-BC (Reward-Aligned Behavior Cloning) parameters + use_rabc: bool = False # Enable reward-weighted training + rabc_progress_path: str | None = None # Path to 
precomputed SARM progress parquet file + rabc_kappa: float = 0.01 # Hard threshold for high-quality samples + rabc_epsilon: float = 1e-6 # Small constant for numerical stability + rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense" + # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) + checkpoint_path: Path | None = field(init=False, default=None) def validate(self) -> None: # HACK: We parse again the cli args here to get the pretrained paths if there was some. @@ -131,6 +139,14 @@ class TrainPipelineConfig(HubMixin): "'policy.repo_id' argument missing. Please specify it to push the model to the hub." ) + if self.use_rabc and not self.rabc_progress_path: + # Auto-detect from dataset path + repo_id = self.dataset.repo_id + if self.dataset.root: + self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet") + else: + self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet" + @classmethod def __get_path_fields__(cls) -> list[str]: """This enables the parser to load config from the policy using `--policy.path=local/dir`""" diff --git a/src/lerobot/data_processing/__init__.py b/src/lerobot/data_processing/__init__.py new file mode 100644 index 000000000..2f76d5676 --- /dev/null +++ b/src/lerobot/data_processing/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/data_processing/sarm_annotations/__init__.py b/src/lerobot/data_processing/sarm_annotations/__init__.py new file mode 100644 index 000000000..2f76d5676 --- /dev/null +++ b/src/lerobot/data_processing/sarm_annotations/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py new file mode 100644 index 000000000..67e37bab8 --- /dev/null +++ b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py @@ -0,0 +1,1202 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +SARM Subtask Annotation using local GPU (Qwen3-VL). + +This script implements the annotation approach from the SARM paper using local GPU inference: +"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation" +Paper: https://arxiv.org/pdf/2509.25358 + +What it does: +1. Takes videos from a LeRobot dataset +2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur +3. Saves subtask timestamps to the dataset metadata +4. Optionally pushes the annotated dataset to HuggingFace Hub + +SARM trains reward models that predict: + - Stage: Which subtask is currently being executed (discrete classification) + - Progress: How far along the subtask we are (continuous 0-1) + +Supports three annotation modes: + 1. No annotations (no args): Auto-creates single sparse "task" stage covering full episode. + Use with SARM config annotation_mode="single_stage" for simple tasks. + + 2. Dense-only (--dense-only --dense-subtasks): Dense annotations from VLM, auto-generated + single sparse "task" stage. Use with annotation_mode="dense_only". + + 3. Dual mode (--sparse-subtasks + --dense-subtasks): Both sparse and dense annotations + from VLM. Use with annotation_mode="dual". + +Requirements: + - GPU with sufficient VRAM (16GB+ recommended for 30B model) + - `pip install transformers torch qwen-vl-utils` + +Run with: +```bash +python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \ + --repo-id your-username/your-dataset \ + --sparse-subtasks "Do ..." 
\
+    --dense-subtasks "Do task 1, Do task 2, Do task 3" \
+    --video-key observation.images.base \
+    --push-to-hub
+```
+"""
+
+import argparse
+import json
+import multiprocessing as mp
+import random
+import re
+import subprocess
+import tempfile
+import textwrap
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from pathlib import Path
+from typing import Any
+
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+from pydantic import BaseModel, Field
+from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+
+# Pydantic Models for SARM Subtask Annotation
+class Timestamp(BaseModel):
+    """Timestamp in MM:SS or SS format"""
+
+    start: str = Field(description="Start timestamp (MM:SS or just seconds)")
+    end: str = Field(description="End timestamp (MM:SS or just seconds)")
+
+
+class Subtask(BaseModel):
+    """Individual subtask/stage - must use EXACT names from provided list"""
+
+    name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
+    timestamps: Timestamp
+
+
+class SubtaskAnnotation(BaseModel):
+    """Complete annotation for a robot manipulation episode"""
+
+    subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
+
+
+def compute_temporal_proportions(
+    annotations: dict[int, Any], fps: int = 30, subtask_order: list[str] | None = None
+) -> dict[str, float]:
+    """
+    Compute dataset-level temporal proportions (priors) for each subtask.
+
+    Implements SARM Paper Formula (1): ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
+
+    Args:
+        annotations: Dict mapping episode index to SubtaskAnnotation object.
+        fps: Frames per second (unused, kept for API compatibility)
+        subtask_order: Optional list defining the output order of subtasks.
+
+    Returns:
+        Dict mapping subtask name to its temporal proportion (ᾱ_k), ordered by subtask_order if provided.
+    """
+    subtask_proportions: dict[str, list[float]] = {}
+
+    for annotation in annotations.values():
+        total_duration = 0
+        durations: dict[str, int] = {}
+
+        for subtask in annotation.subtasks:
+            start_parts = subtask.timestamps.start.split(":")
+            end_parts = subtask.timestamps.end.split(":")
+
+            start_seconds = (
+                int(start_parts[0]) * 60 + int(start_parts[1])
+                if len(start_parts) == 2
+                else int(start_parts[0])
+            )
+            end_seconds = (
+                int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
+            )
+
+            duration = end_seconds - start_seconds
+            durations[subtask.name] = duration
+            total_duration += duration
+
+        if total_duration > 0:
+            for name, duration in durations.items():
+                if name not in subtask_proportions:
+                    subtask_proportions[name] = []
+                subtask_proportions[name].append(duration / total_duration)
+
+    if not subtask_proportions:
+        return {}
+
+    avg_proportions = {name: sum(props) / len(props) for name, props in subtask_proportions.items()}
+
+    total = sum(avg_proportions.values())
+    if total > 0:
+        avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
+
+    # Reorder according to subtask_order if provided
+    if subtask_order:
+        avg_proportions = {
+            name: avg_proportions.get(name, 0.0) for name in subtask_order if name in avg_proportions
+        }
+
+    return avg_proportions
+
+
+def create_sarm_prompt(subtask_list: list[str]) -> str:
+    subtask_str = "\n".join([f" - {name}" for name in subtask_list])
+
+    return textwrap.dedent(f"""\
+    # Role
+    You are a Robotics Vision System specializing in temporal action localization for robot manipulation. Your job is to segment a single demonstration video into distinct, non-overlapping atomic actions from a fixed subtask list.
+
+    # Subtask Label Set (Closed Vocabulary)
+    You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
+
+    [
+    {subtask_str}
+    ]
+
+    The video shows one successful execution of all subtasks in a logical order.
+
+    # Ground-Truth Semantics (Very Important)
+    Use **visual state changes** to define when a subtask starts and ends. Do NOT assume equal durations for the subtasks.
+
+    - A subtask **starts** at the first frame where the robot's motion clearly initiates that subtask.
+    - A subtask **ends** at the first frame where that specific action is visually completed and the manipulated object reaches a temporary, stable configuration.
+
+    If there are short pauses or micro-motions that don't clearly correspond to a new subtask, they belong to the **current** subtask.
+
+    # Hard Constraints & Logic
+    1. **Continuous Coverage (No Gaps):**
+    - The entire video duration from "00:00" to the final timestamp must be covered by subtasks.
+    - There can be no gaps between subtasks.
+    - If there is any idle or ambiguous time between clear actions, extend the *preceding* subtask to cover it.
+
+    2. **Boundary Consistency:**
+    - The `"end"` timestamp of one subtask must be exactly equal to the `"start"` timestamp of the next subtask.
+    - Boundaries must coincide with a real visual state transition, not just a convenient time split.
+
+    3. **Chronological Order, One Occurrence Each:**
+    - This is a single successful demonstration.
+    - Each subtask from the vocabulary appears **exactly once**, in the correct logical order.
+    - **Durations may be very different** between subtasks. Never assume they are similar lengths. Base all boundaries only on the video.
+
+    4. **Reject Uniform Segmentation (Important):**
+    - Do NOT simply divide the video into equal or nearly equal time chunks.
+    - If your boundaries would result in subtasks with similar durations (e.g. all around 5 seconds), treat this as evidence that your segmentation is wrong and refine the boundaries.
+    - Only use nearly equal durations if the video truly shows each subtask taking the same amount of time (this is very rare).
+
+    5. **Timestamps:**
+    - Timestamps must be in `"MM:SS"` format.
+    - The first subtask always starts at `"00:00"`.
+    - The last subtask ends at the final visible frame of the video.
+
+    # Step 1 — Textual Timeline (must do this first)
+    First, write an extensive and detailed textual timeline describing what happens in the video with approximate timestamps.
+    For each subtask, include:
+    - its name
+    - an approximate start and end time,
+    - a description of the visual event at the boundary (e.g. "shirt fully folded to the left", "robot rotates folded shirt 90 degrees").
+
+    Format this as a bullet list.
+
+    # Step 2 — JSON Output (final answer)
+    After the textual timeline, output **only** valid JSON with this structure.
+    The JSON **must** be consistent with the textual timeline above:
+
+    {{
+        "subtasks": [
+            {{
+                "name": "EXACT_NAME_FROM_LIST",
+                "timestamps": {{
+                    "start": "MM:SS",
+                    "end": "MM:SS"
+                }}
+            }},
+            {{
+                "name": "EXACT_NAME_FROM_LIST",
+                "timestamps": {{
+                    "start": "MM:SS",
+                    "end": "MM:SS"
+                }}
+            }}
+        ]
+    }}
+
+    Do not add any extra keys to the JSON.
+    """)
+
+
+class VideoAnnotator:
+    """Annotates robot manipulation videos using local Qwen3-VL model on GPU"""
+
+    def __init__(
+        self,
+        subtask_list: list[str],
+        model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
+        device: str = "cuda",
+        torch_dtype: torch.dtype = torch.bfloat16,
+        model: Qwen3VLMoeForConditionalGeneration | None = None,  # noqa: F821
+        processor: AutoProcessor | None = None,  # noqa: F821
+    ):
+        """
+        Initialize the video annotator with local model.
+
+        Args:
+            subtask_list: List of allowed subtask names (for consistency)
+            model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct)
+            device: Device to use (cuda, cpu)
+            torch_dtype: Data type for model (bfloat16, float16, float32)
+            model: Pre-loaded model instance (optional, to share between annotators)
+            processor: Pre-loaded processor instance (optional, to share between annotators)
+        """
+        self.subtask_list = subtask_list
+        self.prompt = create_sarm_prompt(subtask_list)
+        self.device = device
+
+        # Use provided model/processor or load new ones
+        if model is not None and processor is not None:
+            self.model = model
+            self.processor = processor
+            print(f"Using shared model on {device}")
+        else:
+            from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
+
+            print(f"Loading model: {model_name}...")
+
+            self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+                model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
+            )
+
+            self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+
+            print(f"Model loaded successfully on {device}")
+
+    def extract_episode_segment(
+        self, file_path: Path, start_timestamp: float, end_timestamp: float, target_fps: int = 1
+    ) -> Path:
+        """
+        Extract a specific episode segment from concatenated video.
+        Uses minimal compression to preserve quality for local inference.
+
+        Args:
+            file_path: Path to the concatenated video file
+            start_timestamp: Starting timestamp in seconds (within this video file)
+            end_timestamp: Ending timestamp in seconds (within this video file)
+            target_fps: Target FPS (default: 1 for faster processing)
+
+        Returns:
+            Path to extracted video file
+        """
+        # Create temporary file for extracted video
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
+            tmp_path = Path(tmp_file.name)
+
+        try:
+            # Check if ffmpeg is available
+            subprocess.run(  # nosec B607
+                ["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
+            )
+        except (subprocess.CalledProcessError, FileNotFoundError) as err:
+            raise RuntimeError("ffmpeg not found, cannot extract episode segment") from err
+
+        try:
+            # Calculate duration
+            duration = end_timestamp - start_timestamp
+
+            print(f"Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)")
+
+            # Use ffmpeg to extract segment with minimal quality loss
+            cmd = [
+                "ffmpeg",
+                "-i",
+                str(file_path),
+                "-ss",
+                str(start_timestamp),
+                "-t",
+                str(duration),
+                "-r",
+                str(target_fps),
+                "-c:v",
+                "libx264",
+                "-preset",
+                "ultrafast",
+                "-crf",
+                "23",
+                "-an",
+                "-y",
+                str(tmp_path),
+            ]
+
+            subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
+
+            # Verify the output file was created and is not empty
+            if not tmp_path.exists() or tmp_path.stat().st_size == 0:
+                print("Video extraction failed (0 bytes) - skipping episode")
+                if tmp_path.exists():
+                    tmp_path.unlink()
+                raise RuntimeError("FFmpeg produced empty video file")
+
+            # Show extraction results
+            file_size_mb = tmp_path.stat().st_size / (1024 * 1024)
+
+            # Fail if file is too small (< 100KB likely means extraction failed)
+            if file_size_mb < 0.1:
+                print(f"Extracted video too small ({file_size_mb:.2f}MB) - skipping episode")
+                tmp_path.unlink()
+                raise RuntimeError(f"Video extraction produced invalid file ({file_size_mb:.2f}MB)")
+
+            print(f"Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)")
+
+            return tmp_path
+
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError(f"ffmpeg failed ({e})") from e
+
+    def annotate(
+        self,
+        file_path: str | Path,
+        fps: int,
+        start_timestamp: float = 0.0,
+        end_timestamp: float | None = None,
+        max_retries: int = 3,
+    ) -> SubtaskAnnotation:
+        """Annotate a video segment using local GPU."""
+        from qwen_vl_utils import process_vision_info
+
+        file_path = Path(file_path)
+
+        if end_timestamp is None:
+            cap = cv2.VideoCapture(str(file_path))
+            end_timestamp = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / (cap.get(cv2.CAP_PROP_FPS) or 1)
+            cap.release()
+
+        duration = end_timestamp - start_timestamp
+        duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
+
+        extracted_path = self.extract_episode_segment(file_path, start_timestamp, end_timestamp, 1)
+        is_extracted = extracted_path != file_path
+
+        try:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": self.prompt}]},
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video", "video": str(extracted_path), "fps": 1.0},
+                        {
+                            "type": "text",
+                            "text": f"Video is {duration_str} (~{duration:.1f}s). Follow instructions.",
+                        },
+                    ],
+                },
+            ]
+
+            for attempt in range(max_retries):
+                try:
+                    text = self.processor.apply_chat_template(
+                        messages, tokenize=False, add_generation_prompt=True
+                    )
+                    image_inputs, video_inputs = process_vision_info(messages)
+                    inputs = self.processor(
+                        text=[text],
+                        images=image_inputs,
+                        videos=video_inputs,
+                        padding=True,
+                        return_tensors="pt",
+                    ).to(self.device)
+
+                    with torch.no_grad():
+                        generated_ids = self.model.generate(
+                            **inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
+                        )
+
+                    response = self.processor.batch_decode(
+                        [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
+                        skip_special_tokens=True,
+                    )[0].strip()
+
+                    # Extract JSON
+                    if "```json" in response:
+                        response = response.split("```json")[1].split("```")[0]
+                    elif "```" in response:
+                        response = response.split("```")[1].split("```")[0]
+
+                    try:
+                        return SubtaskAnnotation.model_validate(json.loads(response))
+                    except json.JSONDecodeError:
+                        match = re.search(r"\{.*\}", response, re.DOTALL)
+                        if match:
+                            return SubtaskAnnotation.model_validate(json.loads(match.group()))
+                        raise ValueError("No JSON found") from None
+                except Exception as e:
+                    if attempt == max_retries - 1:
+                        raise RuntimeError(f"Failed after {max_retries} attempts") from e
+                    time.sleep(1)
+        finally:
+            if is_extracted and extracted_path.exists():
+                extracted_path.unlink()
+
+
+def display_annotation(annotation: SubtaskAnnotation, episode_idx: int, fps: int, prefix: str = ""):
+    """Display annotation summary."""
+    subtask_summary = ", ".join(
+        f"{s.name}({s.timestamps.start}-{s.timestamps.end})" for s in annotation.subtasks
+    )
+    print(f"Episode {episode_idx} {prefix}: {len(annotation.subtasks)} subtasks - {subtask_summary}")
+
+
+def timestamp_to_seconds(timestamp: str) -> float:
+    """Convert MM:SS or SS timestamp to seconds"""
+    parts = timestamp.split(":")
+    if len(parts) == 2:
+        return int(parts[0]) * 60 + int(parts[1])
+    else:
+        return int(parts[0])
+
+
+def extract_frame(video_path: Path, timestamp: float) -> np.ndarray | None:
+    """Extract a single frame from video at given timestamp."""
+    cap = cv2.VideoCapture(str(video_path))
+    if not cap.isOpened():
+        return None
+    cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
+    ret, frame = cap.read()
+    cap.release()
+    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if ret else None
+
+
+def draw_timeline(ax, subtasks, total_duration, colors):
+    """Draw a timeline with color-coded subtask segments."""
+    import matplotlib.patches as mpatches
+
+    bar_height, bar_y = 0.6, 0.5
+
+    for i, subtask in enumerate(subtasks):
+        start = timestamp_to_seconds(subtask.timestamps.start)
+        end = timestamp_to_seconds(subtask.timestamps.end)
+        color = colors[i % len(colors)]
+
+        rect = mpatches.FancyBboxPatch(
+            (start, bar_y - bar_height / 2),
+            end - start,
+            bar_height,
+            boxstyle="round,pad=0.02,rounding_size=0.1",
+            facecolor=color,
+            edgecolor="white",
+            linewidth=1.5,
+            alpha=0.85,
+        )
+        ax.add_patch(rect)
+
+        # Add label if segment is wide enough
+        duration = end - start
+        if duration > total_duration * 0.06:
+            ax.text(
+                (start + end) / 2,
+                bar_y,
+                subtask.name,
+                ha="center",
+                va="center",
+                fontsize=8,
+                fontweight="bold",
+                color="white",
+                rotation=0 if duration > total_duration * 0.12 else 45,
+            )
+
+        if i > 0:
+            ax.axvline(x=start, ymin=0.1, ymax=0.9, color="white", linestyle="--", linewidth=1.5, alpha=0.7)
+
+    ax.axvline(x=0, ymin=0.1, ymax=0.9, color="#00ff00", linestyle="-", linewidth=2, alpha=0.9)
+    if subtasks:
+        ax.axvline(
+            x=timestamp_to_seconds(subtasks[-1].timestamps.end),
+            ymin=0.1,
+            ymax=0.9,
+            color="white",
+            linestyle="--",
+            linewidth=1.5,
+            alpha=0.7,
+        )
+
+    ax.set_xlim(-total_duration * 0.02, total_duration * 1.02)
+    ax.set_ylim(-0.1, 1.1)
+    ax.set_xlabel("Time (seconds)", fontsize=10, color="white", labelpad=5)
+    for spine in ["top", "right", "left"]:
+        ax.spines[spine].set_visible(False)
+    ax.spines["bottom"].set_color("#444444")
+    ax.tick_params(axis="x", colors="#888888", labelsize=8)
+    ax.tick_params(axis="y", left=False, labelleft=False)
+
+
+def visualize_episode(
+    ep_idx: int,
+    annotation: SubtaskAnnotation,
+    video_path: Path,
+    video_start: float,
+    video_end: float,
+    output_path: Path,
+    video_key: str,
+    ann_type: str,
+):
+    """Create visualization for a single episode with frames and timeline."""
+    import matplotlib.pyplot as plt
+
+    if annotation is None:
+        print(f"No {ann_type} annotation for episode {ep_idx}")
+        return
+
+    subtasks = annotation.subtasks
+    if not subtasks:
+        print(f"No subtasks for episode {ep_idx}")
+        return
+
+    colors = plt.cm.tab10(np.linspace(0, 1, max(len(subtasks), 10)))
+    total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)
+
+    # Extract middle frame from each subtask
+    sample_frames, frame_times = [], []
+    for subtask in subtasks:
+        start = timestamp_to_seconds(subtask.timestamps.start)
+        end = timestamp_to_seconds(subtask.timestamps.end)
+        mid = (start + end) / 2
+        frame_times.append(mid)
+        sample_frames.append(extract_frame(video_path, video_start + mid))
+
+    # Create figure
+    fig_width = max(16, len(subtasks) * 2.5)
+    fig = plt.figure(figsize=(fig_width, 10))
+    fig.patch.set_facecolor("#1a1a2e")
+
+    gs = fig.add_gridspec(
+        2,
+        max(len(subtasks), 1),
+        height_ratios=[2, 1],
+        hspace=0.3,
+        wspace=0.1,
+        left=0.05,
+        right=0.95,
+        top=0.88,
+        bottom=0.1,
+    )
+
+    fig.suptitle(
+        f"Episode {ep_idx} - {ann_type.capitalize()} Annotations",
+        fontsize=18,
+        fontweight="bold",
+        color="white",
+        y=0.96,
+    )
+    fig.text(
+        0.5,
+        0.91,
+        f"Camera: {video_key} | Duration: {video_end - video_start:.1f}s | {len(subtasks)} subtasks",
+        ha="center",
+        fontsize=11,
+        color="#888888",
+    )
+
+    # Plot frames
+    for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks, strict=True)):
+        ax = fig.add_subplot(gs[0, i])
+        ax.set_facecolor("#16213e")
+        if frame is not None:
+            ax.imshow(frame)
+        else:
+            ax.text(
+                0.5, 0.5, "N/A", ha="center", va="center", fontsize=12, color="white", transform=ax.transAxes
+            )
+        ax.set_title(subtask.name, fontsize=10, fontweight="bold", color=colors[i % len(colors)], pad=8)
+        ax.axis("off")
+        ax.text(
+            0.5,
+            -0.08,
+            f"t={frame_times[i]:.1f}s",
+            ha="center",
+            fontsize=9,
+            color="#888888",
+            transform=ax.transAxes,
+        )
+
+    # Plot timeline
+    ax_timeline = fig.add_subplot(gs[1, :])
+    ax_timeline.set_facecolor("#16213e")
+    draw_timeline(ax_timeline, subtasks, total_duration, colors)
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor="none", bbox_inches="tight")
+    plt.close()
+    print(f"Saved: {output_path}")
+
+
+def visualize_annotations(
+    dataset: LeRobotDataset,
+    sparse_annotations: dict[int, SubtaskAnnotation],
+    dense_annotations: dict[int, SubtaskAnnotation] | None,
+    video_key: str,
+    output_dir: Path,
+    num_episodes: int = 5,
+    annotation_type: str = "sparse",
+    episode_indices: list[int] | None = None,
+):
+    """
+    Visualize subtask annotations for a set of episodes.
+
+    Args:
+        dataset: LeRobotDataset instance
+        sparse_annotations: Dict mapping episode index to sparse annotations
+        dense_annotations: Dict mapping episode index to dense annotations (or None)
+        video_key: Camera/video key to use
+        output_dir: Directory to save visualization images
+        num_episodes: Number of episodes to visualize (ignored if episode_indices provided)
+        annotation_type: "sparse", "dense", or "both"
+        episode_indices: Specific episode indices to visualize (optional)
+    """
+    # Determine available episodes based on annotation type
+    if annotation_type == "sparse":
+        available = set(sparse_annotations.keys())
+    elif annotation_type == "dense":
+        available = set(dense_annotations.keys()) if dense_annotations else set()
+    else:  # both
+        sparse_set = set(sparse_annotations.keys())
+        dense_set = set(dense_annotations.keys()) if dense_annotations else set()
+        available = sparse_set | dense_set
+
+    if not available:
+        print("Error: No annotations found to visualize.")
+        return
+
+    # Select episodes to visualize
+    if episode_indices:
+        episodes = sorted([e for e in episode_indices if e in available])
+        missing = set(episode_indices) - available
+        if missing:
+            print(f"Episodes not found in annotations: {sorted(missing)}")
+    else:
+        episodes = sorted(random.sample(list(available), min(num_episodes, len(available))))
+    print(f"Visualizing {len(episodes)} episodes: {episodes}")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Generate visualizations
+    for i, ep_idx in enumerate(episodes, 1):
+        print(f"Processing episode {ep_idx} ({i}/{len(episodes)})")
+        video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
+        if not video_path.exists():
+            print(f"Video not found: {video_path}")
+            continue
+
+        video_start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
+        video_end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
+
+        if annotation_type == "both":
+            # Visualize both sparse and dense
+            for ann_type, annotations in [("sparse", sparse_annotations), ("dense", dense_annotations)]:
+                if annotations and ep_idx in annotations:
+                    output_path = output_dir / f"episode_{ep_idx:04d}_{ann_type}.png"
+                    visualize_episode(
+                        ep_idx,
+                        annotations.get(ep_idx),
+                        video_path,
+                        video_start,
+                        video_end,
+                        output_path,
+                        video_key,
+                        ann_type,
+                    )
+        else:
+            annotations = sparse_annotations if annotation_type == "sparse" else dense_annotations
+            if annotations and ep_idx in annotations:
+                output_path = output_dir / f"episode_{ep_idx:04d}_{annotation_type}.png"
+                visualize_episode(
+                    ep_idx,
+                    annotations.get(ep_idx),
+                    video_path,
+                    video_start,
+                    video_end,
+                    output_path,
+                    video_key,
+                    annotation_type,
+                )
+
+    print(f"Visualizations saved to: {output_dir.absolute()}")
+
+
+def save_annotations_to_dataset(
+    dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
+):
+    """Save annotations to LeRobot dataset parquet format."""
+    from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
+
+    episodes_dataset = load_episodes(dataset_path)
+    if not episodes_dataset or len(episodes_dataset) == 0:
+        return
+
+    episodes_df = episodes_dataset.to_pandas()
+    cols = [
+        f"{prefix}_{c}"
+        for c in [
+            "subtask_names",
+            "subtask_start_times",
+            "subtask_end_times",
+            "subtask_start_frames",
+            "subtask_end_frames",
+        ]
+    ]
+    for col in cols:
+        episodes_df[col] = None
+
+    for ep_idx, ann in annotations.items():
+        if ep_idx >= len(episodes_df):
+            continue
+        names, starts, ends, start_frames, end_frames = [], [], [], [], []
+        for s in ann.subtasks:
+            names.append(s.name)
+            st, et = timestamp_to_seconds(s.timestamps.start), timestamp_to_seconds(s.timestamps.end)
+            starts.append(st)
+            ends.append(et)
+            start_frames.append(int(st * fps))
+            end_frames.append(int(et * fps))
+        episodes_df.at[ep_idx, cols[0]] = names
+        episodes_df.at[ep_idx, cols[1]] = starts
+        episodes_df.at[ep_idx, cols[2]] = ends
+        episodes_df.at[ep_idx, cols[3]] = start_frames
+        episodes_df.at[ep_idx, cols[4]] = end_frames
+
+    # Group by file and write
+    for ep_idx in episodes_df.index:
+        key = (
+            episodes_df.loc[ep_idx, "meta/episodes/chunk_index"],
+            episodes_df.loc[ep_idx, "meta/episodes/file_index"],
+        )
+        path = dataset_path / DEFAULT_EPISODES_PATH.format(chunk_index=key[0], file_index=key[1])
+        if path.exists():
+            file_df = pd.read_parquet(path)
+            for col in cols + (
+                [
+                    "subtask_names",
+                    "subtask_start_times",
+                    "subtask_end_times",
+                    "subtask_start_frames",
+                    "subtask_end_frames",
+                ]
+                if prefix == "sparse"
+                else []
+            ):
+                if col not in file_df.columns:
+                    file_df[col] = None
+            if ep_idx in annotations:
+                for col in cols:
+                    file_df.at[ep_idx, col] = episodes_df.loc[ep_idx, col]
+                if prefix == "sparse":  # Legacy columns
+                    for i, legacy in enumerate(
+                        [
+                            "subtask_names",
+                            "subtask_start_times",
+                            "subtask_end_times",
+                            "subtask_start_frames",
+                            "subtask_end_frames",
+                        ]
+                    ):
+                        file_df.at[ep_idx, legacy] = episodes_df.loc[ep_idx, cols[i]]
+            file_df.to_parquet(path, engine="pyarrow", compression="snappy")
+
+
+def generate_auto_sparse_annotations(
+    dataset: LeRobotDataset, episode_indices: list[int], video_key: str
+) -> dict[int, SubtaskAnnotation]:
+    """Auto-generate single 'task' stage annotations for all episodes."""
+    annotations = {}
+    for ep_idx in episode_indices:
+        start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
+        end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
+        duration = end - start
+        end_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
+        annotations[ep_idx] = SubtaskAnnotation(
+            subtasks=[Subtask(name="task", timestamps=Timestamp(start="00:00", end=end_str))]
+        )
+    return annotations
+
+
+def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
+    """Load annotations from LeRobot dataset parquet files."""
+    from lerobot.datasets.utils import load_episodes
+
+    episodes_dataset = load_episodes(dataset_path)
+    if not episodes_dataset or len(episodes_dataset) == 0:
+        return {}
+
+    col_names = f"{prefix}_subtask_names"
+    col_start = f"{prefix}_subtask_start_times"
+    col_end = f"{prefix}_subtask_end_times"
+
+    # Fall back to legacy columns for sparse
+    if col_names not in episodes_dataset.column_names:
+        if prefix == "sparse" and "subtask_names" in episodes_dataset.column_names:
+            col_names, col_start, col_end = "subtask_names", "subtask_start_times", "subtask_end_times"
+        else:
+            return {}
+
+    df = episodes_dataset.to_pandas()
+    annotations = {}
+    for ep_idx in df.index:
+        names = df.loc[ep_idx, col_names]
+        if names is None or (isinstance(names, float) and pd.isna(names)):
+            continue
+        starts, ends = df.loc[ep_idx, col_start], df.loc[ep_idx, col_end]
+        annotations[int(ep_idx)] = SubtaskAnnotation(
+            subtasks=[
+                Subtask(
+                    name=n,
+                    timestamps=Timestamp(
+                        start=f"{int(s) // 60:02d}:{int(s) % 60:02d}",
+                        end=f"{int(e) // 60:02d}:{int(e) % 60:02d}",
+                    ),
+                )
+                for n, s, e in zip(names, starts, ends, strict=True)
+            ]
+        )
+    return annotations
+
+
+def process_single_episode(
+    ep_idx: int,
+    dataset_root: Path,
+    dataset_meta,
+    video_key: str,
+    fps: int,
+    annotator: VideoAnnotator,
+) -> tuple[int, SubtaskAnnotation | None, str | None]:
+    """Process a single episode annotation."""
+    try:
+        video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key)
+        if not video_path.exists():
+            return ep_idx, None, f"Video not found: {video_path}"
+
+        start = float(dataset_meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
+        end = float(dataset_meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
+        return ep_idx, annotator.annotate(video_path, fps, start, end), None
+    except Exception as e:
+        return ep_idx, None, str(e)
+
+
+def worker_process_episodes(
+    worker_id: int,
+    gpu_id: int,
+    episode_indices: list[int],
+    repo_id: str,
+    video_key: str,
+    sparse_subtask_list: list[str],
+    dense_subtask_list: list[str] | None,
+    model_name: str,
+    torch_dtype: torch.dtype,
+) -> tuple[dict, dict | None]:
+    """Worker for parallel processing across GPUs."""
+    device = f"cuda:{gpu_id}"
+    dataset = LeRobotDataset(repo_id, download_videos=False)
+
+    sparse_annotator = VideoAnnotator(sparse_subtask_list, model_name, device, torch_dtype)
+    dense_annotator = (
+        VideoAnnotator(
+            dense_subtask_list,
+            model_name,
+            device,
+            torch_dtype,
+            sparse_annotator.model,
+            sparse_annotator.processor,
+        )
+        if dense_subtask_list
+        else None
+    )
+
+    sparse_annotations, dense_annotations = {}, {} if dense_subtask_list else None
+
+    for ep_idx in episode_indices:
+        _, sparse_ann, err = process_single_episode(
+            ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, sparse_annotator
+        )
+        if sparse_ann:
+            sparse_annotations[ep_idx] = sparse_ann
+
+        if dense_annotator:
+            _, dense_ann, _ = process_single_episode(
+                ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, dense_annotator
+            )
+            if dense_ann:
+                dense_annotations[ep_idx] = dense_ann
+
+    return sparse_annotations, dense_annotations
+
+
+def main():
+    parser = argparse.ArgumentParser(description="SARM-style subtask annotation using local GPU (Qwen3-VL)")
+    parser.add_argument("--repo-id", type=str, required=True, help="HuggingFace dataset repository ID")
+    parser.add_argument(
+        "--sparse-subtasks", type=str, default=None, help="Comma-separated sparse subtask names"
+    )
+    parser.add_argument(
+        "--dense-subtasks", type=str, default=None, help="Comma-separated dense subtask names"
+    )
+    parser.add_argument(
+        "--dense-only", action="store_true", help="Dense-only mode with auto-generated sparse 'task' stage"
+    )
+    parser.add_argument("--episodes", type=int, nargs="+", default=None, help="Episode indices to annotate")
+    parser.add_argument("--model", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="VLM model")
+    parser.add_argument("--skip-existing", action="store_true", help="Skip already annotated episodes")
+    parser.add_argument("--video-key", type=str, default=None, help="Video key (default: first available)")
+    parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
+    parser.add_argument("--output-repo-id", type=str, default=None, help="Output repo ID for push")
+    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
+    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
+    parser.add_argument("--num-workers", type=int, default=1, help="Parallel workers for multi-GPU")
+    parser.add_argument("--gpu-ids", type=int, nargs="+", default=None, help="GPU IDs to use")
+    # Visualization options
+    parser.add_argument(
+        "--visualize-only",
+        action="store_true",
+        help="Only visualize existing annotations (no generation)",
+    )
+    parser.add_argument(
+        "--num-visualizations",
+        type=int,
+        default=5,
+        help="Number of episodes to visualize (default: 5)",
+    )
+    parser.add_argument(
+        "--visualize-type",
+        type=str,
+        default="sparse",
+        choices=["sparse", "dense", "both"],
+        help="Type of annotations to visualize (default: sparse)",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./subtask_viz",
+        help="Output directory for visualizations (default: ./subtask_viz)",
+    )
+
+    args = parser.parse_args()
+
+    # Load dataset first (needed for both annotation and visualization)
+    print(f"Loading dataset: {args.repo_id}")
+    dataset = LeRobotDataset(args.repo_id, download_videos=True)
+    fps = dataset.fps
+
+    if not dataset.meta.video_keys:
+        raise ValueError("No video keys found")
+
+    video_key = (
+        args.video_key if args.video_key in (dataset.meta.video_keys or []) else dataset.meta.video_keys[0]
+    )
+    print(f"Using camera: {video_key}, FPS: {fps}")
+
+    # Handle visualization-only mode
+    if args.visualize_only:
+        print("Visualization-only mode")
+        sparse_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse")
+        dense_annotations = load_annotations_from_dataset(dataset.root, prefix="dense")
+
+        if not sparse_annotations and not dense_annotations:
+            return print("Error: No annotations found. Run annotation first.")
+
+        print(f"Found {len(sparse_annotations)} sparse, {len(dense_annotations)} dense annotations")
+
+        visualize_annotations(
+            dataset=dataset,
+            sparse_annotations=sparse_annotations,
+            dense_annotations=dense_annotations if dense_annotations else None,
+            video_key=video_key,
+            output_dir=Path(args.output_dir),
+            num_episodes=args.num_visualizations,
+            annotation_type=args.visualize_type,
+            episode_indices=args.episodes,
+        )
+        return
+
+    # Validate arguments for annotation mode
+    if args.dense_only and not args.dense_subtasks:
+        return print("Error: --dense-only requires --dense-subtasks")
+    if args.dense_subtasks and not args.sparse_subtasks and not args.dense_only:
+        return print("Error: --dense-subtasks requires --sparse-subtasks or --dense-only")
+
+    sparse_subtask_list = (
+        [s.strip() for s in args.sparse_subtasks.split(",")] if args.sparse_subtasks else None
+    )
+    dense_subtask_list = [s.strip() for s in args.dense_subtasks.split(",")] if args.dense_subtasks else None
+    auto_sparse = sparse_subtask_list is None
+    dense_mode = dense_subtask_list is not None
+    torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
+
+    # Determine episodes
+    episode_indices = args.episodes or list(range(dataset.meta.total_episodes))
+
+    existing_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse")
+    if args.skip_existing:
+        episode_indices = [ep for ep in episode_indices if ep not in existing_annotations]
+
+    if not episode_indices:
+        return print("All episodes already annotated!")
+    print(f"Annotating {len(episode_indices)} episodes")
+
+    # GPU setup
+    gpu_ids = args.gpu_ids or list(
+        range(min(args.num_workers, torch.cuda.device_count() if torch.cuda.is_available() else 1))
+    )
+    args.num_workers = len(gpu_ids)
+
+    sparse_annotations = existing_annotations.copy()
+    dense_annotations = {} if dense_mode else None
+
+    # Auto-sparse mode
+    if auto_sparse:
+        sparse_annotations.update(generate_auto_sparse_annotations(dataset, episode_indices, video_key))
+        save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
+        print(f"Auto-generated {len(episode_indices)} sparse 'task' annotations")
+
+    # VLM annotation (for sparse if not auto, and for dense)
+    need_vlm = (not auto_sparse) or dense_mode
+
+    if need_vlm:
+        if args.num_workers > 1 and not auto_sparse:
+            # Parallel processing
+            print(f"Parallel processing with {args.num_workers} workers")
+            episodes_per_worker = [[] for _ in range(args.num_workers)]
+            for i, ep_idx in enumerate(episode_indices):
+                episodes_per_worker[i % args.num_workers].append(ep_idx)
+
+            with ProcessPoolExecutor(
+                max_workers=args.num_workers, mp_context=mp.get_context("spawn")
+            ) as executor:
+                futures = [
+                    executor.submit(
+                        worker_process_episodes,
+                        w,
+                        gpu_ids[w],
+                        episodes_per_worker[w],
+                        args.repo_id,
+                        video_key,
+                        sparse_subtask_list,
+                        dense_subtask_list,
+                        args.model,
+                        torch_dtype,
+                    )
+                    for w in range(args.num_workers)
+                    if episodes_per_worker[w]
+                ]
+
+                for future in as_completed(futures):
+                    try:
+                        worker_sparse, worker_dense = future.result()
+                        sparse_annotations.update(worker_sparse)
+                        if dense_mode and worker_dense:
+                            dense_annotations.update(worker_dense)
+                        save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
+                        if dense_mode:
+                            save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
+                    except Exception as e:
+                        raise RuntimeError(f"Worker failed: {e}") from e
+        else:
+            # Sequential processing
+            sparse_annotator = (
+                VideoAnnotator(sparse_subtask_list, args.model, args.device, torch_dtype)
+                if not auto_sparse and sparse_subtask_list
+                else None
+            )
+            dense_annotator = (
+                VideoAnnotator(
+                    dense_subtask_list,
+                    args.model,
+                    args.device,
+                    torch_dtype,
+                    sparse_annotator.model if sparse_annotator else None,
+                    sparse_annotator.processor if sparse_annotator else None,
+                )
+                if dense_mode
+                else None
+            )
+
+            for i, ep_idx in enumerate(episode_indices):
+                print(f"Episode {ep_idx} ({i + 1}/{len(episode_indices)})")
+
+                if sparse_annotator:
+                    _, sparse_ann, err = process_single_episode(
+                        ep_idx, dataset.root, dataset.meta, video_key, fps, sparse_annotator
+                    )
+                    if sparse_ann:
+                        sparse_annotations[ep_idx] = sparse_ann
+                        save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
+                    elif err:
+                        print(f"Sparse failed: {err}")
+
+                if dense_annotator:
+                    _, dense_ann, err = process_single_episode(
+                        ep_idx, dataset.root, dataset.meta, video_key, fps, dense_annotator
+                    )
+                    if dense_ann:
+                        dense_annotations[ep_idx] = dense_ann
+                        save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
+                    elif err:
+                        print(f"Dense failed: {err}")
+
+    # Save temporal proportions
+    def save_proportions(annotations, prefix, subtask_list=None, is_auto=False):
+        props: dict[str, float] = (
+            {"task": 1.0} if is_auto else compute_temporal_proportions(annotations, fps, subtask_list)
+        )
+        path = dataset.root / "meta" / f"temporal_proportions_{prefix}.json"
+        path.parent.mkdir(parents=True, exist_ok=True)
+        with open(path, "w") as f:
+            json.dump(props, f, indent=2)
+        print(f"Saved {prefix} temporal proportions")
+
+    save_proportions(sparse_annotations, "sparse", sparse_subtask_list, auto_sparse)
+    if dense_mode and dense_annotations:
+        save_proportions(dense_annotations, "dense", dense_subtask_list)
+
+    print(f"\nComplete! {len(sparse_annotations)} sparse, {len(dense_annotations or {})} dense annotations")
+
+    # Visualize annotations after generation
+    if args.num_visualizations > 0:
+        print(f"\nGenerating {args.num_visualizations} visualizations...")
+        visualize_type = "both" if dense_mode else "sparse"
+        visualize_annotations(
+            dataset=dataset,
+            sparse_annotations=sparse_annotations,
+            dense_annotations=dense_annotations,
+            video_key=video_key,
+            output_dir=Path(args.output_dir),
+            num_episodes=args.num_visualizations,
+            annotation_type=visualize_type,
+        )
+
+    if args.push_to_hub:
+        try:
+            dataset.push_to_hub(push_videos=True)
+            print(f"Pushed to {args.output_repo_id or args.repo_id}")
+        except Exception as e:
+            print(f"Push failed: {e}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py
index 788542d49..ceefb0d56 100644
--- a/src/lerobot/policies/__init__.py
+++ b/src/lerobot/policies/__init__.py
@@ -29,6 +29,7 @@ __all__ = [
     "PI0Config",
     "PI05Config",
     "SmolVLAConfig",
+    "SARMConfig",
     "TDMPCConfig",
     "VQBeTConfig",
     "GrootConfig",
diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py
index b7cbcd061..a5c48eb3d 100644
--- a/src/lerobot/policies/act/modeling_act.py
+++ b/src/lerobot/policies/act/modeling_act.py
@@ -50,6 +50,7 @@ class ACTPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: ACTConfig,
+        **kwargs,
     ):
         """
         Args:
diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py
index 3ab6719cb..1fdc76f10 100644
--- a/src/lerobot/policies/diffusion/modeling_diffusion.py
+++ b/src/lerobot/policies/diffusion/modeling_diffusion.py
@@ -56,6 +56,7 @@ class DiffusionPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: DiffusionConfig,
+        **kwargs,
     ):
         """
         Args:
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index 3d17fa7dc..eb1ff41f7 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -37,6 +37,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
 from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.utils import validate_visual_features_consistency
@@ -105,6 +106,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
 
         return SmolVLAPolicy
+    elif name == "sarm":
+        from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
+
+        return SARMRewardModel
     elif name == "groot":
         from lerobot.policies.groot.modeling_groot import GrootPolicy
 
@@ -337,6 +342,14 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )
 
+    elif isinstance(policy_cfg, SARMConfig):
+        from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
+
+        processors = make_sarm_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            dataset_meta=kwargs.get("dataset_meta"),
+        )
     elif isinstance(policy_cfg, GrootConfig):
         from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
 
@@ -435,6 +448,13 @@ def make_policy(
         cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
     kwargs["config"] = cfg
 
+    # Pass dataset_stats to the policy if available (needed for some policies like SARM)
+    if ds_meta is not None and hasattr(ds_meta, "stats"):
+        kwargs["dataset_stats"] = ds_meta.stats
+
+    if ds_meta is not None:
+        kwargs["dataset_meta"] = ds_meta
+
     if cfg.pretrained_path:
         # Load a pretrained policy and override the config if needed (for example, if there
are inference-time # hyperparameters that we want to vary). diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 605f7a097..bdaef37b9 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy): name = "groot" config_class = GrootConfig - def __init__(self, config: GrootConfig): + def __init__(self, config: GrootConfig, **kwargs): """Initialize Groot policy wrapper.""" super().__init__(config) config.validate_features() diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 4b79c2902..0d9c77e00 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -907,6 +907,7 @@ class PI0Policy(PreTrainedPolicy): def __init__( self, config: PI0Config, + **kwargs, ): """ Args: @@ -1235,9 +1236,15 @@ class PI0Policy(PreTrainedPolicy): return actions - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: - """Run the batch through the model and compute the loss for training.""" + def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training. + Args: + batch: Training batch containing observations and actions. + reduction: How to reduce the loss. 
Options: + - "mean": Return scalar mean loss (default, backward compatible) + - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting + """ # Prepare inputs images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] @@ -1251,11 +1258,17 @@ class PI0Policy(PreTrainedPolicy): original_action_dim = self.config.output_features[ACTION].shape[0] losses = losses[:, :, :original_action_dim] - loss = losses.mean() - loss_dict = { - "loss": loss.item(), "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), } - return loss, loss_dict + if reduction == "none": + # Return per-sample losses (B,) by averaging over time and action dims + per_sample_loss = losses.mean(dim=(1, 2)) + loss_dict["loss"] = per_sample_loss.mean().item() + return per_sample_loss, loss_dict + else: + # Default: return scalar mean loss + loss = losses.mean() + loss_dict["loss"] = loss.item() + return loss, loss_dict diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 64eb4cb23..2cd142042 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -880,6 +880,7 @@ class PI05Policy(PreTrainedPolicy): def __init__( self, config: PI05Config, + **kwargs, ): """ Args: @@ -1209,9 +1210,15 @@ class PI05Policy(PreTrainedPolicy): return actions - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: - """Run the batch through the model and compute the loss for training.""" + def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training. + Args: + batch: Training batch containing observations and actions. + reduction: How to reduce the loss. 
Options: + - "mean": Return scalar mean loss (default, backward compatible) + - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting + """ # Prepare inputs images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] @@ -1225,11 +1232,17 @@ class PI05Policy(PreTrainedPolicy): original_action_dim = self.config.output_features[ACTION].shape[0] losses = losses[:, :, :original_action_dim] - loss = losses.mean() - loss_dict = { - "loss": loss.item(), "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), } - return loss, loss_dict + if reduction == "none": + # Return per-sample losses (B,) by averaging over time and action dims + per_sample_loss = losses.mean(dim=(1, 2)) + loss_dict["loss"] = per_sample_loss.mean().item() + return per_sample_loss, loss_dict + else: + # Default: return scalar mean loss + loss = losses.mean() + loss_dict["loss"] = loss.item() + return loss, loss_dict diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py new file mode 100644 index 000000000..5b6ea6e9b --- /dev/null +++ b/src/lerobot/policies/sarm/compute_rabc_weights.py @@ -0,0 +1,870 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Compute SARM progress values for RA-BC (Reward-Aligned Behavior Cloning) weighting.
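The `reduction="none"` branch added to `forward` above returns per-sample losses of shape `(batch_size,)` so that a training loop can rescale each sample before reducing. A minimal sketch of that weighted reduction on plain Python floats (the helper name `weighted_bc_loss` and the example weight values are illustrative, not part of this diff):

```python
def weighted_bc_loss(per_sample_losses, weights):
    """Combine per-sample BC losses with per-sample RA-BC weights.

    per_sample_losses: one loss per batch element, i.e. what
    forward(batch, reduction="none") would yield, detached to floats.
    weights: non-negative per-sample weights, e.g. derived from SARM progress.
    Returns the weight-normalized mean loss.
    """
    assert len(per_sample_losses) == len(weights)
    total_w = sum(weights)
    if total_w == 0:
        # Degenerate batch (all weights zero): fall back to the unweighted mean.
        return sum(per_sample_losses) / len(per_sample_losses)
    return sum(l * w for l, w in zip(per_sample_losses, weights)) / total_w


# High-weight samples dominate the loss; a low-weight sample (0.1, e.g. a
# hesitation frame with flat progress) contributes almost nothing.
loss = weighted_bc_loss([0.5, 2.0, 1.0], [1.0, 0.1, 1.0])
```

With uniform weights this reduces to the ordinary mean, so the default `reduction="mean"` behavior is recovered as a special case.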
+ +This script processes all frames in a dataset with SARM to compute progress values [0, 1]. +The results are saved as a parquet file that can be loaded during training for RA-BC weighting. + +Uses multi-output extraction: each SARM query returns progress for 9 frames, so we only +need ~num_frames/30 queries instead of one per frame (~30x speedup). + +Usage: + # Full RA-BC computation with visualizations + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 + + # Faster computation with stride (compute every 5 frames, interpolate the rest) + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 \\ + --stride 5 + + # Visualize predictions only (no RA-BC computation) + python src/lerobot/policies/sarm/compute_rabc_weights.py \\ + --dataset-repo-id lerobot/aloha_sim_insertion_human \\ + --reward-model-path pepijn223/sarm_single_uni4 \\ + --visualize-only \\ + --num-visualizations 5 + +The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'. 
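The progress values this script writes come from normalizing the raw `k + τ` prediction with the dataset's temporal proportions, i.e. Formula (2): `progress = P_{k-1} + ᾱ_k × τ`. A simplified stand-in for `normalize_stage_tau` showing just that arithmetic (the real helper additionally takes `subtask_names` to resolve per-subtask proportions):

```python
def stage_tau_to_progress(stage_tau: float, proportions: list[float]) -> float:
    """Map a raw stage+tau value y = k + tau to normalized progress in [0, 1].

    proportions: per-subtask temporal proportions (alpha_k), summing to 1.0,
    as stored in meta/temporal_proportions_*.json.
    """
    k = min(int(stage_tau), len(proportions) - 1)  # stage index
    tau = stage_tau - k                            # within-stage progress in [0, 1]
    cumulative = sum(proportions[:k])              # P_{k-1}: prior stages completed
    return cumulative + proportions[k] * tau       # Formula (2)


# With proportions [0.5, 0.3, 0.2], halfway through stage 1:
# 0.5 + 0.3 * 0.5 -> 0.65
p = stage_tau_to_progress(1.5, [0.5, 0.3, 0.2])
```

Because the proportions are dataset-level priors, identical task states map to the same progress value regardless of how long an individual demonstration lingered in each stage.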
+""" + + import argparse + import logging + from pathlib import Path + from typing import Any + + import matplotlib.gridspec as gridspec + import matplotlib.pyplot as plt + import numpy as np + import pyarrow as pa + import pyarrow.parquet as pq + import torch + from tqdm import tqdm + + from lerobot.datasets.lerobot_dataset import LeRobotDataset + from lerobot.policies.sarm.modeling_sarm import SARMRewardModel + from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors + from lerobot.policies.sarm.sarm_utils import normalize_stage_tau + + + def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None: + """Read reward_model_path from parquet metadata if available.""" + if not parquet_path.exists(): + return None + try: + metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata + if metadata and b"reward_model_path" in metadata: + return metadata[b"reward_model_path"].decode() + except Exception: # nosec B110 + return None + return None + + + def load_sarm_resources( + dataset_repo_id: str, + reward_model_path: str, + device: str = "cuda", + ) -> tuple[LeRobotDataset, SARMRewardModel, Any]: + """ + Load SARM model, dataset, and preprocessor.
+ + Returns: + Tuple of (dataset, reward_model, preprocessor) + """ + logging.info(f"Loading model: {reward_model_path}") + reward_model = SARMRewardModel.from_pretrained(reward_model_path) + reward_model.config.device = device + reward_model.to(device).eval() + + image_key = reward_model.config.image_key + state_key = reward_model.config.state_key + delta_indices = reward_model.config.observation_delta_indices + + logging.info(f"Loading dataset: {dataset_repo_id}") + temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True) + fps = temp_dataset.fps + + delta_timestamps = { + image_key: [idx / fps for idx in delta_indices], + state_key: [idx / fps for idx in delta_indices], + } + dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps) + logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames") + + preprocess, _ = make_sarm_pre_post_processors( + config=reward_model.config, + dataset_stats=dataset.meta.stats, + dataset_meta=dataset.meta, + ) + + return dataset, reward_model, preprocess + + +def to_numpy_image(img) -> np.ndarray: + """Convert image tensor to numpy uint8 (H, W, C).""" + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + if img.ndim == 4: + # Take center frame for bidirectional sampling + img = img[img.shape[0] // 2] + if img.shape[0] in [1, 3]: + img = np.transpose(img, (1, 2, 0)) + if img.dtype != np.uint8: + # Handle normalized images (may have negative values or values > 1) + img = img.astype(np.float32) + img = (img - img.min()) / (img.max() - img.min() + 1e-8) # Normalize to [0, 1] + img = (img * 255).astype(np.uint8) + return img + + +def visualize_episode( + frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None +): + """Create visualization with progress plot, stage probabilities, and sample frames. 
+ + Same as sarm_inference_visualization.py + """ + num_stages = stage_preds.shape[1] + colors = plt.cm.tab10(np.linspace(0, 1, num_stages)) + frame_indices = np.arange(len(progress_preds)) + + fig = plt.figure(figsize=(14, 12)) + gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3) + ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2]) + + # Progress plot + ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted") + ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB") + if gt_progress is not None: + ax_progress.plot( + frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth" + ) + ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5) + ax_progress.set_ylabel("Progress") + ax_progress.set_title(f'Task: "{title}"', fontweight="bold") + ax_progress.set_ylim(-0.05, 1.1) + ax_progress.legend(loc="upper left") + ax_progress.grid(True, alpha=0.3) + + # Stage predictions + ax_stages.stackplot( + frame_indices, + *[stage_preds[:, i] for i in range(num_stages)], + colors=colors, + alpha=0.8, + labels=stage_labels, + ) + if gt_stages is not None: + for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1: + ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5) + ax_stages.set_xlabel("Frame") + ax_stages.set_ylabel("Stage Probability") + ax_stages.set_ylim(0, 1) + ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8) + ax_stages.grid(True, alpha=0.3) + + # Sample frames + ax_frames.axis("off") + num_sample = 8 + sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int) + h, w = frames[0].shape[:2] + combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8) + for i, idx in enumerate(sample_indices): + frame = frames[idx] + if frame.shape[-1] == 1: + frame = np.repeat(frame, 3, axis=-1) + combined[:, i * w : (i + 1) * w] 
= frame + stage_name = stage_labels[np.argmax(stage_preds[idx])][:12] + ax_frames.text( + i * w + w / 2, + -10, + f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}", + ha="center", + va="top", + fontsize=7, + ) + ax_frames.imshow(combined) + ax_frames.set_title("Sample Frames", pad=20) + + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved: {output_path}") + + +def visualize_sarm_predictions( + dataset: LeRobotDataset, + reward_model: SARMRewardModel, + preprocess, + episode_indices: list[int], + head_mode: str, + output_dir: Path, + num_display_frames: int = 5, + stride: int = 1, +): + """ + Visualize SARM predictions for multiple episodes. + + Computes predictions for every frame by default. With stride > 1, computes predictions + every N frames and interpolates (progress + stage probabilities) for visualization. + + Args: + dataset: LeRobotDataset with delta_timestamps configured + reward_model: Loaded SARM model + preprocess: Preprocessor from make_sarm_pre_post_processors + episode_indices: List of episode indices to visualize + head_mode: "sparse", "dense", or "both" + output_dir: Directory to save visualizations + num_display_frames: Number of frames to display in thumbnail strip (default: 5) + stride: Compute predictions every N frames, interpolate the rest (default: 1) + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + image_key = reward_model.config.image_key + state_key = reward_model.config.state_key + dual_mode = reward_model.config.uses_dual_heads + device = reward_model.device + + # Center frame index for bidirectional sampling + target_idx = reward_model.config.n_obs_steps // 2 + + # Determine which heads to visualize + schemes_to_viz = [] + if head_mode in ("sparse", "both") or not dual_mode: + schemes_to_viz.append("sparse") + if head_mode in ("dense", "both") and dual_mode: + schemes_to_viz.append("dense") + + # 
Set preprocessor to eval mode to disable augmentations + if hasattr(preprocess, "eval"): + preprocess.eval() + for step in preprocess.steps: + if hasattr(step, "eval"): + step.eval() + + for episode_idx in episode_indices: + ep = dataset.meta.episodes[episode_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + task = dataset[ep_start].get("task", "perform the task") + num_frames = ep_end - ep_start + + # Select frames for display thumbnails (evenly sampled from begin to end) + display_indices = set( + [ + ep_start + int(i * (num_frames - 1) / (num_display_frames - 1)) + for i in range(num_display_frames) + ] + if num_frames >= num_display_frames + else list(range(ep_start, ep_end)) + ) + viz_frames = {} + + # Load display frames up-front (stride mode might skip them otherwise). + for frame_idx in display_indices: + sample = dataset[frame_idx] + viz_frames[frame_idx] = to_numpy_image(sample[image_key]) + + # Initialize storage for each scheme + scheme_data = {} + for scheme in schemes_to_viz: + num_stages = getattr(reward_model.config, f"num_{scheme}_stages") + scheme_data[scheme] = { + "viz_progress": np.full(num_frames, np.nan), + "viz_stages": np.full((num_frames, num_stages), np.nan), + "viz_gt_progress": np.full(num_frames, np.nan), + "viz_gt_stages": np.full(num_frames, np.nan), + "target_key": f"{scheme}_targets", + "num_stages": num_stages, + "temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"), + "subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"), + } + + if stride > 1: + logging.info(f"Visualization stride={stride}: inferring every {stride} frames and interpolating") + + # Process frames one at a time to avoid memory buildup + frame_indices = list(range(ep_start, ep_end, stride)) + if (ep_end - 1) not in frame_indices: + frame_indices.append(ep_end - 1) + frame_indices = sorted(set(frame_indices)) + + for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", 
leave=False): + local_idx = frame_idx - ep_start + sample = dataset[frame_idx] + + batch = { + image_key: sample[image_key], + "task": task, + "index": frame_idx, + "episode_index": episode_idx, + } + if state_key in sample: + batch[state_key] = sample[state_key] + + with torch.no_grad(): + processed = preprocess(batch) + video_features = processed["video_features"].to(device) + text_features = processed["text_features"].to(device) + state_features = processed.get("state_features") + if state_features is not None: + state_features = state_features.to(device) + lengths = processed.get("lengths") + + for scheme in schemes_to_viz: + sd = scheme_data[scheme] + + # Ground truth + # In stride visualization mode, ground-truth plots can be misleading + # (only sparse points are available), so we skip GT. + if stride == 1 and sd["target_key"] in processed: + gt_target = processed[sd["target_key"]][0, target_idx].cpu().item() + sd["viz_gt_stages"][local_idx] = int(gt_target) + sd["viz_gt_progress"][local_idx] = normalize_stage_tau( + gt_target, + num_stages=sd["num_stages"], + temporal_proportions=sd["temporal_props"], + subtask_names=sd["subtask_names"], + ) + + # Predictions + reward, stage_probs = reward_model.calculate_rewards( + text_embeddings=text_features, + video_embeddings=video_features, + state_features=state_features, + lengths=lengths, + return_all_frames=True, + return_stages=True, + head_mode=scheme, + ) + + # Handle both tensor and numpy outputs + if isinstance(reward, torch.Tensor): + reward = reward.cpu().numpy() + stage_probs = stage_probs.cpu().numpy() + + if reward.ndim == 2: + sd["viz_progress"][local_idx] = reward[0, target_idx] + sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :] + else: + sd["viz_progress"][local_idx] = reward[target_idx] + sd["viz_stages"][local_idx] = stage_probs[target_idx, :] + + # Clear GPU memory after each frame + del processed, video_features, text_features + if state_features is not None: + del state_features + + 
torch.cuda.empty_cache() + + # Interpolate predictions back to per-frame arrays for smooth visualization. + if stride > 1: + all_local = np.arange(num_frames) + for scheme in schemes_to_viz: + sd = scheme_data[scheme] + + valid = np.isfinite(sd["viz_progress"]) + valid_idx = np.where(valid)[0] + if valid_idx.size >= 1: + sd["viz_progress"] = interpolate_progress( + valid_idx, sd["viz_progress"][valid_idx], all_local + ) + + stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32) + for s in range(sd["num_stages"]): + stage_interp[:, s] = interpolate_progress( + valid_idx, sd["viz_stages"][valid_idx, s], all_local + ) + + stage_interp = np.clip(stage_interp, 0.0, 1.0) + row_sums = stage_interp.sum(axis=1, keepdims=True) + nz = row_sums.squeeze(-1) > 0 + stage_interp[nz] = stage_interp[nz] / row_sums[nz] + sd["viz_stages"] = stage_interp + else: + # No valid points: keep NaNs/zeros; visualization will be empty. + sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0) + + # Generate visualization for each head + ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)] + for scheme in schemes_to_viz: + sd = scheme_data[scheme] + stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])] + viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png" + + visualize_episode( + frames=np.array(ordered_viz_frames), + progress_preds=sd["viz_progress"], + stage_preds=sd["viz_stages"], + title=f"{task} (Episode {episode_idx})", + output_path=viz_path, + stage_labels=stage_labels, + gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None, + gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None, + ) + + # Clear memory between episodes + torch.cuda.empty_cache() + + logging.info(f"Visualizations saved to: {output_dir.absolute()}") + + +def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]: + """Generate all 
frame indices, ordered by offset for cache-friendly access. + + Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...] + This groups frames that share similar temporal windows together. + """ + num_frames = ep_end - ep_start + indices = [] + for offset in range(frame_gap): + for frame_rel in range(offset, num_frames, frame_gap): + indices.append(ep_start + frame_rel) + return indices + + +def interpolate_progress( + computed_indices: np.ndarray, + computed_values: np.ndarray, + all_indices: np.ndarray, +) -> np.ndarray: + """Linearly interpolate values to fill in gaps (robust to NaNs / edge cases).""" + computed_indices = np.asarray(computed_indices) + computed_values = np.asarray(computed_values) + all_indices = np.asarray(all_indices) + + mask = np.isfinite(computed_values) + if mask.sum() == 0: + return np.full(all_indices.shape, np.nan, dtype=np.float32) + if mask.sum() == 1: + return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32) + + out = np.interp(all_indices, computed_indices[mask], computed_values[mask]) + return out.astype(np.float32) + + +def compute_sarm_progress( + dataset_repo_id: str, + reward_model_path: str, + output_path: str | None = None, + head_mode: str = "sparse", + device: str = "cuda", + num_visualizations: int = 5, + output_dir: str = "./sarm_viz", + stride: int = 1, +): + """ + Compute SARM progress predictions for all frames in a dataset. + + Args: + dataset_repo_id: HuggingFace dataset repo ID or local path + reward_model_path: Path to pretrained SARM model + output_path: Path to save results. 
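To make the offset-ordered traversal of `generate_all_frame_indices` concrete, here is the same logic run on a toy episode (a self-contained copy of the function defined above, with a small `frame_gap` so the ordering is visible):

```python
def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
    """Yield every frame index once, grouped by offset modulo frame_gap."""
    num_frames = ep_end - ep_start
    indices = []
    for offset in range(frame_gap):
        # All frames sharing this offset land in one contiguous run,
        # so queries with similar temporal windows are processed together.
        for frame_rel in range(offset, num_frames, frame_gap):
            indices.append(ep_start + frame_rel)
    return indices


# 7-frame episode with frame_gap=3: offset-0 frames first, then offset-1, offset-2.
order = generate_all_frame_indices(0, 7, frame_gap=3)  # -> [0, 3, 6, 1, 4, 2, 5]
```

Every frame still appears exactly once; only the visitation order changes, which is why the results are later re-sorted by `index` before the parquet table is written.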
If None, saves to dataset's cache directory + head_mode: SARM head to use ("sparse", "dense", or "both") + device: Device to use for inference + num_visualizations: Number of episodes to visualize (0 to skip) + output_dir: Directory to save visualizations + stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame) + """ + dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device) + + # Set preprocessor to eval mode to disable augmentations + if hasattr(preprocess, "eval"): + preprocess.eval() + for step in preprocess.steps: + if hasattr(step, "eval"): + step.eval() + + image_key = reward_model.config.image_key + state_key = reward_model.config.state_key + frame_gap = reward_model.config.frame_gap + num_episodes = dataset.num_episodes + total_frames = dataset.num_frames + logging.info(f"Processing {total_frames} frames across {num_episodes} episodes") + + # Determine which heads to compute + dual_mode = reward_model.config.uses_dual_heads + compute_sparse = head_mode in ("sparse", "both") or not dual_mode + compute_dense = head_mode in ("dense", "both") and dual_mode + + # Storage arrays + all_indices = [] + all_episode_indices = [] + all_frame_indices = [] + all_progress_sparse = [] if compute_sparse else None + all_progress_dense = [] if compute_dense else None + + if stride > 1: + logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest") + + # Process all episodes + for episode_idx in tqdm(range(num_episodes), desc="Episodes"): + ep = dataset.meta.episodes[episode_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + + # Get task description + task = dataset[ep_start].get("task", "perform the task") + + # Generate frames to compute (with stride applied) + all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap) + if stride > 1: + # Only compute every stride-th frame (relative to episode start) + compute_indices = 
[idx for idx in all_ep_indices if (idx - ep_start) % stride == 0] + # Always include last frame for better interpolation at episode end + last_frame = ep_end - 1 + if last_frame not in compute_indices: + compute_indices.append(last_frame) + compute_indices = sorted(set(compute_indices)) + else: + compute_indices = all_ep_indices + + center_idx = reward_model.config.n_obs_steps // 2 # Center of bidirectional window + + # Dictionary to collect results + frame_results = {} + + for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False): + try: + sample = dataset[query_idx] + + batch = { + image_key: sample[image_key], + "task": task, + "index": query_idx, + "episode_index": episode_idx, + } + if state_key in sample: + batch[state_key] = sample[state_key] + + with torch.no_grad(): + processed = preprocess(batch) + video_features = processed["video_features"].to(device) + text_features = processed["text_features"].to(device) + state_features = processed.get("state_features") + if state_features is not None: + state_features = state_features.to(device) + lengths = processed.get("lengths") + + sparse_val = np.nan + dense_val = np.nan + + # Compute sparse prediction for center frame + if compute_sparse: + sparse_progress = reward_model.calculate_rewards( + text_embeddings=text_features, + video_embeddings=video_features, + state_features=state_features, + lengths=lengths, + return_all_frames=True, + head_mode="sparse", + ) + sparse_val = float( + sparse_progress[0, center_idx] + if sparse_progress.ndim == 2 + else sparse_progress[center_idx] + ) + + # Compute dense prediction for center frame + if compute_dense: + dense_progress = reward_model.calculate_rewards( + text_embeddings=text_features, + video_embeddings=video_features, + state_features=state_features, + lengths=lengths, + return_all_frames=True, + head_mode="dense", + ) + dense_val = float( + dense_progress[0, center_idx] + if dense_progress.ndim == 2 + else dense_progress[center_idx] + ) + + 
frame_results[query_idx] = (sparse_val, dense_val) + + except Exception as e: + logging.warning(f"Failed to process frame {query_idx}: {e}") + + # Interpolate to get values for all frames + computed_indices = np.array(sorted(frame_results.keys())) + computed_sparse = ( + np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None + ) + computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None + + # All frame indices for this episode + all_frame_idx_array = np.arange(ep_start, ep_end) + + if stride > 1 and len(computed_indices) > 1: + # Interpolate progress values + if compute_sparse: + interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array) + if compute_dense: + interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array) + else: + # No interpolation needed + interp_sparse = computed_sparse if compute_sparse else None + interp_dense = computed_dense if compute_dense else None + + # Store results for all frames + for i, frame_idx in enumerate(all_frame_idx_array): + local_idx = frame_idx - ep_start + all_indices.append(frame_idx) + all_episode_indices.append(episode_idx) + all_frame_indices.append(local_idx) + if compute_sparse: + if stride > 1 and len(computed_indices) > 1: + all_progress_sparse.append(float(interp_sparse[i])) + elif frame_idx in frame_results: + all_progress_sparse.append(frame_results[frame_idx][0]) + else: + all_progress_sparse.append(np.nan) + if compute_dense: + if stride > 1 and len(computed_indices) > 1: + all_progress_dense.append(float(interp_dense[i])) + elif frame_idx in frame_results: + all_progress_dense.append(frame_results[frame_idx][1]) + else: + all_progress_dense.append(np.nan) + + # Create output table + table_data = { + "index": np.array(all_indices, dtype=np.int64), + "episode_index": np.array(all_episode_indices, dtype=np.int64), + "frame_index": np.array(all_frame_indices, dtype=np.int64), + 
} + if compute_sparse: + table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32) + if compute_dense: + table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32) + + # Sort by index + df = pa.table(table_data).to_pandas() + df = df.sort_values("index").reset_index(drop=True) + final_table = pa.Table.from_pandas(df, preserve_index=False) + + # Add metadata with reward model path + metadata = {b"reward_model_path": reward_model_path.encode()} + final_table = final_table.replace_schema_metadata(metadata) + + # Determine output path + output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path) + + # Save + output_path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(final_table, output_path) + logging.info(f"Saved {len(final_table)} frame progress values to {output_path}") + + # Print statistics + if "progress_sparse" in df.columns: + valid = df["progress_sparse"].dropna() + logging.info( + f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, " + f"min={valid.min():.4f}, max={valid.max():.4f}" + ) + + if "progress_dense" in df.columns: + valid = df["progress_dense"].dropna() + logging.info( + f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, " + f"min={valid.min():.4f}, max={valid.max():.4f}" + ) + + # Visualize episodes after processing + if num_visualizations > 0: + viz_episodes = list(range(min(num_visualizations, num_episodes))) + logging.info(f"Generating {len(viz_episodes)} visualizations...") + visualize_sarm_predictions( + dataset=dataset, + reward_model=reward_model, + preprocess=preprocess, + episode_indices=viz_episodes, + head_mode=head_mode, + output_dir=Path(output_dir), + stride=stride, + ) + + return output_path + + +def main(): + parser = argparse.ArgumentParser( + description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" 
+
+Examples:
+    # Full RA-BC computation with visualizations
+    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
+        --reward-model-path pepijn223/sarm_single_uni4
+
+    # Visualize predictions only (no RA-BC computation)
+    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
+        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
+        --reward-model-path pepijn223/sarm_single_uni4 \\
+        --visualize-only \\
+        --num-visualizations 10
+        """,
+    )
+    parser.add_argument(
+        "--dataset-repo-id",
+        type=str,
+        required=True,
+        help="HuggingFace dataset repo ID or local path",
+    )
+    parser.add_argument(
+        "--reward-model-path",
+        type=str,
+        default=None,
+        help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)",
+    )
+    parser.add_argument(
+        "--output-path",
+        type=str,
+        default=None,
+        help="Output path for parquet. If not set, saves to dataset's cache directory",
+    )
+    parser.add_argument(
+        "--head-mode",
+        type=str,
+        default="sparse",
+        choices=["sparse", "dense", "both"],
+        help="SARM head to use (default: sparse)",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device to use (default: cuda)",
+    )
+    # Visualization options
+    parser.add_argument(
+        "--visualize-only",
+        action="store_true",
+        help="Only visualize SARM predictions (no RA-BC computation)",
+    )
+    parser.add_argument(
+        "--num-visualizations",
+        type=int,
+        default=5,
+        help="Number of episodes to visualize (default: 5, set to 0 to skip)",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="./sarm_viz",
+        help="Output directory for visualizations (default: ./sarm_viz)",
+    )
+    # Note: must default to False so the flag is meaningful; with default=True a
+    # store_true flag could never be disabled and the local-path branch in main()
+    # would be unreachable.
+    parser.add_argument(
+        "--push-to-hub",
+        action="store_true",
+        help="Upload progress file to the dataset repo on HuggingFace Hub",
+    )
+    parser.add_argument(
+        "--stride",
+        type=int,
+        default=1,
+        help="Compute progress every N frames, interpolate the rest (default: 1
= every frame)", + ) + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + + # Try to get reward_model_path from parquet metadata if not provided + reward_model_path = args.reward_model_path + if reward_model_path is None: + # Load dataset to find parquet path + temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False) + parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet" + reward_model_path = get_reward_model_path_from_parquet(parquet_path) + if reward_model_path: + logging.info(f"Using reward model from parquet metadata: {reward_model_path}") + else: + raise ValueError( + "--reward-model-path is required (no existing parquet with model metadata found)" + ) + + # Handle visualize-only mode + if args.visualize_only: + dataset, reward_model, preprocess = load_sarm_resources( + args.dataset_repo_id, reward_model_path, args.device + ) + logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes") + viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes))) + visualize_sarm_predictions( + dataset=dataset, + reward_model=reward_model, + preprocess=preprocess, + episode_indices=viz_episodes, + head_mode=args.head_mode, + output_dir=Path(args.output_dir), + stride=args.stride, + ) + print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}") + return + + # Full RABC computation (compute_sarm_progress loads model/dataset itself) + output_path = compute_sarm_progress( + dataset_repo_id=args.dataset_repo_id, + reward_model_path=reward_model_path, + output_path=args.output_path, + head_mode=args.head_mode, + device=args.device, + num_visualizations=args.num_visualizations, + output_dir=args.output_dir, + stride=args.stride, + ) + + print(f"\nSARM progress values saved to: {output_path}") + + # Upload to Hub if requested + if args.push_to_hub: + from huggingface_hub import HfApi + + api = HfApi() + hub_path = 
"sarm_progress.parquet" + + print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}") + api.upload_file( + path_or_fileobj=str(output_path), + path_in_repo=hub_path, + repo_id=args.dataset_repo_id, + repo_type="dataset", + ) + print( + f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}" + ) + + print("\nTo use in training, add to your config:") + print(" use_rabc: true") + print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}") + print(" rabc_head_mode: sparse # or dense") + else: + print("\nTo use in training, add to your config:") + print(" use_rabc: true") + print(f" rabc_progress_path: {output_path}") + print(" rabc_head_mode: sparse # or dense") + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py new file mode 100644 index 000000000..59cb352d5 --- /dev/null +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python + +# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation. 
+Paper: https://arxiv.org/abs/2509.25358 +""" + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("sarm") +@dataclass +class SARMConfig(PreTrainedConfig): + """Configuration class for SARM (Stage-Aware Reward Modeling). + + Supports three annotation modes: + + 1. single_stage (default): No annotations needed. Uses the episode's task description + as a single stage covering the entire episode. + + 2. dense_only: Uses dense (fine-grained) annotations from VLM, with an auto-generated + single sparse "task" stage covering the full episode. The dense head learns detailed + subtask progression while sparse provides overall task completion. + + 3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained) + annotations from VLM. Both heads are trained on their respective annotations. + + The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions + are loaded/generated during model initialization. 
+ """ + + annotation_mode: str = "single_stage" # "single_stage", "dense_only", or "dual" + n_obs_steps: int = 8 # Number of observation history steps + frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second) + max_rewind_steps: int = 4 # Maximum rewind steps for temporal augmentation + + # Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property) + # During training with rewind: [obs_frames] + [rewind_frames] + # During inference: [obs_frames] only + + # Architecture params + image_dim: int = 512 + text_dim: int = 512 + hidden_dim: int = 768 + num_heads: int = 12 + num_layers: int = 8 + max_state_dim: int = 32 + drop_n_last_frames: int = 1 + batch_size: int = 64 + clip_batch_size: int = 64 + dropout: float = 0.1 + stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations + + rewind_probability: float = 0.8 + language_perturbation_probability: float = 0.2 + + # Sparse annotations (high-level stages) + num_sparse_stages: int = 1 + sparse_subtask_names: list | None = None + sparse_temporal_proportions: list | None = None + + # Dense annotations (fine-grained stages) + num_dense_stages: int | None = None + dense_subtask_names: list | None = None + dense_temporal_proportions: list | None = None + + pretrained_model_path: str | None = None + device: str | None = None + image_key: str = "observation.images.top" # Key for image used from the dataset + state_key: str = "observation.state" + + # Populated by the processor (video_features, state_features, text_features) + input_features: dict = field(default_factory=lambda: {}) + + # Output features (updated in __post_init__) + output_features: dict = field( + default_factory=lambda: { + "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD), + "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD), + } + ) + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, 
+ "STATE": NormalizationMode.MEAN_STD, + "LANGUAGE": NormalizationMode.IDENTITY, + "REWARD": NormalizationMode.IDENTITY, + } + ) + + def __post_init__(self): + super().__post_init__() + + if self.annotation_mode not in ["single_stage", "dense_only", "dual"]: + raise ValueError( + f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}" + ) + + if self.annotation_mode == "single_stage": + # Use task description as stage name, full episode as one stage + self.num_sparse_stages = 1 + self.sparse_subtask_names = ["task"] + self.sparse_temporal_proportions = [1.0] + self.num_dense_stages = None + self.dense_subtask_names = None + self.dense_temporal_proportions = None + + elif self.annotation_mode == "dense_only": + self.num_sparse_stages = 1 + self.sparse_subtask_names = ["task"] + self.sparse_temporal_proportions = [1.0] + + self.input_features = {} + self.output_features = {} + + if self.image_key: + self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL) + + self.input_features[self.state_key] = PolicyFeature( + shape=(self.max_state_dim,), + type=FeatureType.STATE, + ) + + # Update output features based on annotation_mode + if self.annotation_mode in ["dense_only", "dual"]: + self.output_features["sparse_stage"] = PolicyFeature( + shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD + ) + self.output_features["sparse_progress"] = PolicyFeature( + shape=(self.num_frames, 1), type=FeatureType.REWARD + ) + dense_stages = self.num_dense_stages or self.num_sparse_stages + self.output_features["dense_stage"] = PolicyFeature( + shape=(self.num_frames, dense_stages), type=FeatureType.REWARD + ) + self.output_features["dense_progress"] = PolicyFeature( + shape=(self.num_frames, 1), type=FeatureType.REWARD + ) + else: + self.output_features["sparse_stage"] = PolicyFeature( + shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD + ) + 
self.output_features["sparse_progress"] = PolicyFeature( + shape=(self.num_frames, 1), type=FeatureType.REWARD + ) + + if self.max_rewind_steps >= self.n_obs_steps: + raise ValueError( + f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})" + ) + if self.num_sparse_stages < 1: + raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}") + if ( + self.annotation_mode in ["dense_only", "dual"] + and self.num_dense_stages is not None + and self.num_dense_stages < 2 + ): + raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}") + + def get_optimizer_preset(self) -> AdamWConfig: + """Get default optimizer configuration for SARM training.""" + return AdamWConfig( + lr=5e-5, + weight_decay=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + ) + + def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: + """Get default learning rate scheduler configuration.""" + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=5e-5, + decay_lr=5e-6, + num_warmup_steps=500, + num_decay_steps=50000, + ) + + def validate_features(self) -> None: + pass + + @property + def uses_dual_heads(self) -> bool: + """Whether the model uses dual heads (dense_only or dual annotation modes).""" + return self.annotation_mode in ["dense_only", "dual"] + + @property + def num_frames(self) -> int: + """Total number of frames in sequence. + + For training: 1 + n_obs_steps + max_rewind_steps + The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)] + """ + return 1 + self.n_obs_steps + self.max_rewind_steps + + @property + def max_length(self) -> int: + return self.num_frames + + @property + def observation_delta_indices(self) -> list[int]: + """Bidirectional frame sampling centered on target frame. 
+ + Example with n_obs_steps=8, gap=30: + Before: [-120, -90, -60, -30] (4 frames) + Current: [0] (1 frame) + After: [30, 60, 90, 120] (4 frames) + Total: 9 frames + """ + half_steps = self.n_obs_steps // 2 + + past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)] + future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)] + obs_deltas = past_deltas + [0] + future_deltas + + # Rewind placeholders + rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)] + + return obs_deltas + rewind_deltas + + @property + def action_delta_indices(self) -> None: + """SARM is a reward model, not an action policy.""" + return None + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py new file mode 100644 index 000000000..a88b2ad64 --- /dev/null +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -0,0 +1,793 @@ +#!/usr/bin/env python + +# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation. 
+ +Paper: https://arxiv.org/abs/2509.25358 + +- StageTransformer: Predicts stage classification (sparse/dense) +- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage +""" + +import json +import logging +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.sarm.sarm_utils import ( + normalize_stage_tau, + pad_state_to_max_dim, +) + + +class StageTransformer(nn.Module): + """ + Stage classification transformer for SARM. + + Predicts which stage/subtask the current frame belongs to. + Supports both sparse (high-level) and dense (fine-grained) annotation schemes. + + Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D) + Output: stage logits (B, T, num_classes) + """ + + def __init__( + self, + d_model: int = 512, + vis_emb_dim: int = 512, + text_emb_dim: int = 512, + state_dim: int = 32, + n_layers: int = 6, + n_heads: int = 8, + dropout: float = 0.1, + num_cameras: int = 1, + num_classes_sparse: int = 4, + num_classes_dense: int = 8, + ): + super().__init__() + self.d_model = d_model + self.num_cameras = num_cameras + + # Projections + self.lang_proj = nn.Linear(text_emb_dim, d_model) + self.visual_proj = nn.Linear(vis_emb_dim, d_model) + self.state_proj = nn.Linear(state_dim, d_model) + + # Encoder + enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True) + self.transformer = nn.TransformerEncoder(enc_layer, n_layers) + + # Positional bias on first visual frame + self.first_pos = nn.Parameter(torch.zeros(1, d_model)) + + # Shared fusion MLP + # Fuses (num_cameras + 2) streams: cameras + lang + state + fused_in = d_model * (num_cameras + 2) + self.fusion_backbone = nn.Sequential( + nn.LayerNorm(fused_in), + nn.Linear(fused_in, d_model), + 
nn.ReLU(), + ) + + # Scheme-specific heads + self.heads = nn.ModuleDict( + { + "sparse": nn.Linear(d_model, num_classes_sparse), + "dense": nn.Linear(d_model, num_classes_dense), + } + ) + + def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803 + """ + Prepare language embeddings for fusion. + + Accepts lang_emb of shape: + - (B, text_emb_dim) -> broadcast across time + - (B, T, text_emb_dim) -> per-timestep (dense annotation mode) + + Returns: (B, 1, T, D) + """ + if lang_emb.dim() == 3: + # (B, T, E) -> (B, T, D) -> (B, 1, T, D) + lang_proj = self.lang_proj(lang_emb).unsqueeze(1) + else: + # (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D) + lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D) + return lang_proj + + def forward( + self, + img_seq: torch.Tensor, # (B, N, T, vis_emb_dim) + lang_emb: torch.Tensor, # (B, E) or (B, T, E) + state: torch.Tensor, # (B, T, state_dim) + lengths: torch.Tensor, # (B,) - valid sequence lengths + scheme: str = "sparse", # "sparse" or "dense" + ) -> torch.Tensor: + """ + Forward pass for stage classification. + + Args: + img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras + lang_emb: Language embeddings (B, E) or (B, T, E) for dense + state: State features (B, T, state_dim) + lengths: Valid sequence lengths (B,) for masking + scheme: "sparse" or "dense" for head selection + + Returns: + Stage logits (B, T, num_classes) + """ + assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}." 
+ + B, N, T, _ = img_seq.shape # noqa: N806 + D = self.d_model # noqa: N806 + device = img_seq.device + + # Project inputs + vis_proj = self.visual_proj(img_seq) # (B, N, T, D) + state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D) + lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D) + + # Concatenate streams + # cameras + lang + state -> (B, N+2, T, D) + x = torch.cat([vis_proj, lang_proj, state_proj], dim=1) + + # Add positional bias to first visual frame + x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos + + # Flatten to tokens for Transformer + x_tokens = x.view(B, (N + 2) * T, D) + L = x_tokens.size(1) # noqa: N806 + + # Create padding mask + base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) # (B, T) + mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T) + + # Create causal mask + causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1) + + # Encode + h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True) + + # Reshape and fuse + h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D) + fused = self.fusion_backbone(h) # (B, T, D) + + # Scheme-specific logits + logits = self.heads[scheme](fused) # (B, T, num_classes) + return logits + + +class SubtaskTransformer(nn.Module): + """ + Subtask progress regression transformer for SARM. + + Predicts within-stage normalized progress (tau) conditioned on stage prior. + The stage prior is a one-hot encoding passed from StageTransformer predictions. 
+ + Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D) + Output: tau predictions (B, T) in [0, 1] + """ + + def __init__( + self, + d_model: int = 512, + vis_emb_dim: int = 512, + text_emb_dim: int = 512, + state_dim: int = 32, + n_layers: int = 6, + n_heads: int = 8, + dropout: float = 0.1, + num_cameras: int = 1, + ): + super().__init__() + self.d_model = d_model + self.num_cameras = num_cameras + + # Projections + self.lang_proj = nn.Linear(text_emb_dim, d_model) + self.visual_proj = nn.Linear(vis_emb_dim, d_model) + self.state_proj = nn.Linear(state_dim, d_model) + + # Encoder + enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True) + self.transformer = nn.TransformerEncoder(enc, n_layers) + + # Learned bias on first visual frame + self.first_pos = nn.Parameter(torch.zeros(1, d_model)) + + # Shared fusion backbone + # Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb + fused_in = d_model * (num_cameras + 3) + self.fusion_backbone = nn.Sequential( + nn.LayerNorm(fused_in), + nn.Linear(fused_in, d_model), + nn.ReLU(), + ) + + # Scheme-specific regression heads + self.heads = nn.ModuleDict( + { + "sparse": nn.Linear(d_model, 1), + "dense": nn.Linear(d_model, 1), + } + ) + + def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803 + """ + Prepare language embeddings for fusion. + """ + if lang_emb.dim() == 3: + # (B, T, E) -> (B, T, D) -> (B, 1, T, D) + return self.lang_proj(lang_emb).unsqueeze(1) + else: + # (B, E) -> (B, 1, 1, D) -> (B, 1, T, D) + return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D) + + def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor: + """ + Deterministic projection of one-hot stage to d_model by pad/truncate. 
+ + Args: + stage_prior: One-hot stage embedding (B, 1, T, C) + + Returns: + Projected stage embedding (B, 1, T, d_model) + """ + B, one, T, C = stage_prior.shape # noqa: N806 + D = self.d_model # noqa: N806 + if D == C: + return stage_prior + elif D > C: + pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype) + return torch.cat([stage_prior, pad], dim=-1) + else: + return stage_prior[..., :D] + + def forward( + self, + img_seq: torch.Tensor, # (B, N, T, vis_emb_dim) + lang_emb: torch.Tensor, # (B, E) or (B, T, E) + state: torch.Tensor, # (B, T, state_dim) + lengths: torch.Tensor, # (B,) - valid sequence lengths + stage_prior: torch.Tensor, # (B, 1, T, C) one-hot from gen_stage_emb + scheme: str = "sparse", # "sparse" or "dense" + ) -> torch.Tensor: + """ + Forward pass for subtask progress regression. + + Args: + img_seq: Image embeddings (B, N, T, vis_emb_dim) + lang_emb: Language embeddings (B, E) or (B, T, E) + state: State features (B, T, state_dim) + lengths: Valid sequence lengths (B,) for masking + stage_prior: One-hot stage prior (B, 1, T, num_classes) + scheme: "sparse" or "dense" for head selection + + Returns: + Tau predictions (B, T) in [0, 1] via sigmoid + """ + assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}." 
+ + B, N, T, _ = img_seq.shape # noqa: N806 + D = self.d_model # noqa: N806 + device = img_seq.device + + # Project inputs + vis_proj = self.visual_proj(img_seq) # (B, N, T, D) + state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D) + lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D) + stage_emb = self._stage_to_dmodel(stage_prior) # (B, 1, T, D) + + # Concatenate all streams + # cameras + lang + state + stage_emb -> (B, N+3, T, D) + x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1) + + # Add positional bias to first visual frame + x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos + + # Flatten to tokens + x_tokens = x.view(B, (N + 3) * T, D) + L = x_tokens.size(1) # noqa: N806 + + # Create padding mask + base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) + mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T) + + # Create causal mask + causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1) + + # Encode + h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True) + + # Reshape and fuse + h = h.view(B, N + 3, T, D) + h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D) + fused = self.fusion_backbone(h_flat) # (B, T, D) + + # Scheme-specific regression head -> sigmoid + r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1) # (B, T) + return r + + +def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor: + """ + Generate one-hot stage embeddings from targets. 
+ + Args: + num_classes: Number of stage classes + targets: Target values (B, T) where integer part is stage index + + Returns: + One-hot stage embedding (B, 1, T, num_classes) + """ + # Integer part of float targets -> [0, C-1] + idx = targets.long().clamp(min=0, max=num_classes - 1) # (B, T) + C = num_classes # noqa: N806 + # Identity-lookup one-hot + stage_onehot = torch.eye(C, device=targets.device)[idx] # (B, T, C) + stage_onehot = stage_onehot.unsqueeze(1) # (B, 1, T, C) + return stage_onehot + + +class SARMRewardModel(PreTrainedPolicy): + """ + SARM Reward Model for stage-aware task completion rewards. + + Uses two separate transformer models: + - StageTransformer: Classifies which stage/subtask + - SubtaskTransformer: Predicts within-stage progress (tau) + + Training uses 75%/25% GT/predicted stage conditioning (teacher forcing). + """ + + name = "sarm" + config_class = SARMConfig + + def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None): + super().__init__(config, dataset_stats) + config.validate_features() + self.config = config + self.dataset_stats = dataset_stats + self.device = torch.device( + config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu" + ) + + # Load temporal proportions based on annotation_mode + if config.annotation_mode == "single_stage": + logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}") + elif dataset_meta is not None: + self._load_temporal_proportions(dataset_meta) + + # Create two separate models + self.stage_model = StageTransformer( + d_model=config.hidden_dim, + vis_emb_dim=config.image_dim, + text_emb_dim=config.text_dim, + state_dim=config.max_state_dim, + n_layers=config.num_layers, + n_heads=config.num_heads, + dropout=config.dropout, + num_cameras=1, # Single camera for now + num_classes_sparse=config.num_sparse_stages, + num_classes_dense=config.num_dense_stages or config.num_sparse_stages, + ) + + 
self.subtask_model = SubtaskTransformer( + d_model=config.hidden_dim, + vis_emb_dim=config.image_dim, + text_emb_dim=config.text_dim, + state_dim=config.max_state_dim, + n_layers=config.num_layers, + n_heads=config.num_heads, + dropout=config.dropout, + num_cameras=1, + ) + + self.stage_model.to(self.device) + self.subtask_model.to(self.device) + + # GT/predicted stage ratio for teacher forcing + self.gt_stage_ratio = 0.75 + + if config.uses_dual_heads: + logging.info( + f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, " + f"{config.num_dense_stages} dense stages" + ) + else: + logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages") + + logging.info(f"SARM initialized on {self.device}") + + def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]: + """Load temporal proportions from a JSON file (preserving order).""" + if not path.exists(): + raise ValueError( + f"{annotation_type.capitalize()} temporal proportions not found at {path}. " + f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations." 
+ ) + with open(path) as f: + proportions_dict = json.load(f) + names = list(proportions_dict.keys()) + logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}") + logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}") + return names, [proportions_dict[name] for name in names] + + def _load_temporal_proportions(self, dataset_meta) -> None: + """Load temporal proportions based on annotation_mode.""" + meta_path = dataset_meta.root / "meta" + + if self.config.annotation_mode == "dual": + names, props = self._load_proportions_from_json( + meta_path / "temporal_proportions_sparse.json", "sparse" + ) + ( + self.config.num_sparse_stages, + self.config.sparse_subtask_names, + self.config.sparse_temporal_proportions, + ) = len(names), names, props + + if self.config.annotation_mode in ["dense_only", "dual"]: + names, props = self._load_proportions_from_json( + meta_path / "temporal_proportions_dense.json", "dense" + ) + ( + self.config.num_dense_stages, + self.config.dense_subtask_names, + self.config.dense_temporal_proportions, + ) = len(names), names, props + if self.config.annotation_mode == "dense_only": + logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}") + + def to(self, device): + """Override to method to ensure all components move together.""" + super().to(device) + self.device = device if isinstance(device, torch.device) else torch.device(device) + self.stage_model.to(device) + self.subtask_model.to(device) + return self + + @torch.no_grad() + def calculate_rewards( + self, + text_embeddings: np.ndarray | torch.Tensor, + video_embeddings: np.ndarray | torch.Tensor, + state_features: np.ndarray | torch.Tensor | None = None, + lengths: np.ndarray | torch.Tensor | None = None, + return_all_frames: bool = False, + return_stages: bool = False, + return_confidence: bool = False, + head_mode: str | None = "sparse", + frame_index: int | None = None, + ) -> np.ndarray | tuple: + """ 
+ Calculate rewards for given text, video, and state representations. + + This is the canonical method for SARM reward computation, used for: + - Inference/visualization + - RA-BC weight computation + + Args: + text_embeddings: Encoded text representations (batch_size, 512) + video_embeddings: Encoded video representations (batch_size, num_frames, 512) + state_features: Joint state features (batch_size, num_frames, state_dim) + lengths: Valid sequence lengths (batch_size,) + return_all_frames: If True, return rewards for all frames + return_stages: If True, also return stage predictions + return_confidence: If True, also return stage confidence + head_mode: Which head to use ("sparse" or "dense") + frame_index: Index of the target frame to extract (default: n_obs_steps). + + Returns: + Rewards and optionally stage probs/confidence. + """ + if isinstance(text_embeddings, np.ndarray): + text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32) + if isinstance(video_embeddings, np.ndarray): + video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32) + if state_features is not None and isinstance(state_features, np.ndarray): + state_features = torch.tensor(state_features, dtype=torch.float32) + + # Handle single sample case + if text_embeddings.dim() == 1: + text_embeddings = text_embeddings.unsqueeze(0) + video_embeddings = video_embeddings.unsqueeze(0) + if state_features is not None: + state_features = state_features.unsqueeze(0) + single_sample = True + else: + single_sample = False + + batch_size = video_embeddings.shape[0] + seq_len = video_embeddings.shape[1] + + scheme = head_mode + + # Default lengths if not provided + if lengths is None: + lengths = torch.full((batch_size,), seq_len, dtype=torch.int32) + elif isinstance(lengths, np.ndarray): + lengths = torch.tensor(lengths, dtype=torch.int32) + + # Reshape video to (B, N, T, D) for multi-camera format + # Currently single camera: (B, T, D) -> (B, 1, T, D) + img_seq = 
video_embeddings.unsqueeze(1).to(self.device) + lang_emb = text_embeddings.to(self.device) + state = ( + state_features.to(self.device) + if state_features is not None + else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device) + ) + lens = lengths.to(self.device) + + # Pad state to max_state_dim + state = pad_state_to_max_dim(state, self.config.max_state_dim) + + # Get num_classes for this scheme + num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages + + # Run stage model + stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme) + stage_probs = F.softmax(stage_logits, dim=-1) # (B, T, num_classes) + stage_idx = stage_probs.argmax(dim=-1) # (B, T) + stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1) # (B, T) + + # Create one-hot stage prior + stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C) + stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C) + + # Run subtask model + tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme) + + # Compute final reward: stage + tau + raw_reward = stage_idx.float() + tau_pred # (B, T) + + # Normalize to [0, 1] using temporal proportions for proper weighting + if scheme == "sparse": + normalized_reward = normalize_stage_tau( + raw_reward, + num_stages=num_classes, + temporal_proportions=self.config.sparse_temporal_proportions, + subtask_names=self.config.sparse_subtask_names, + ) + else: + normalized_reward = normalize_stage_tau( + raw_reward, + num_stages=num_classes, + temporal_proportions=self.config.dense_temporal_proportions, + subtask_names=self.config.dense_subtask_names, + ) + + # Default frame index is n_obs_steps (last observation frame) + if frame_index is None: + frame_index = self.config.n_obs_steps + + # Prepare outputs (batch mode or no smoothing) + if return_all_frames: + rewards = normalized_reward.cpu().numpy() + else: + rewards = 
normalized_reward[:, frame_index].cpu().numpy()
+
+        if single_sample:
+            # First (and only) batch element, whether per-frame array or scalar
+            rewards = rewards[0]
+
+        outputs = [rewards]
+        if return_stages:
+            probs = stage_probs.cpu().numpy()
+            if single_sample:
+                probs = probs[0]
+            outputs.append(probs)
+        if return_confidence:
+            conf = stage_conf.cpu().numpy()
+            if single_sample:
+                conf = conf[0]
+            outputs.append(conf)
+
+        return outputs[0] if len(outputs) == 1 else tuple(outputs)
+
+    def train(self, mode: bool = True):
+        """Set training mode for both models."""
+        super().train(mode)
+        self.stage_model.train(mode)
+        self.subtask_model.train(mode)
+        return self
+
+    def eval(self):
+        """Set evaluation mode for both models."""
+        return self.train(False)
+
+    def parameters(self):
+        """Override to return trainable parameters from both models."""
+        from itertools import chain
+
+        return chain(self.stage_model.parameters(), self.subtask_model.parameters())
+
+    def get_optim_params(self):
+        """Override to return optimizer parameters from both models."""
+        return self.parameters()
+
+    def reset(self):
+        """Required by PreTrainedPolicy but not used for reward models."""
+        pass
+
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Required by PreTrainedPolicy but not used for reward models."""
+        raise NotImplementedError("SARM model does not predict action chunks")
+
+    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
+        """Required by PreTrainedPolicy but not used for SARM."""
+        raise NotImplementedError("SARM model does not select actions")
+
+    def _train_step(
+        self,
+        img_emb: torch.Tensor,  # (B, N, T, D)
+        lang_emb: torch.Tensor,  # (B, E) or (B, T, E)
+        state: torch.Tensor,  # (B, T, state_dim)
+        lengths: torch.Tensor,  # (B,)
+        targets: torch.Tensor,  # (B, T) - format: stage.tau
+        scheme: str,
+    ) -> dict[str, torch.Tensor]:
+        """
+        Single training step for one annotation scheme.
+
+        Implements 75%/25% GT/predicted stage conditioning (ratio controlled
+        by ``gt_stage_ratio``).
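+
+        Worked example (hypothetical target value): ``targets[b, t] = 2.35``
+        decodes to ``gt_stage = floor(2.35) = 2`` and
+        ``gt_tau = 2.35 mod 1.0 = 0.35``.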
+ + Args: + img_emb: Image embeddings (B, N, T, D) + lang_emb: Language embeddings + state: State features + lengths: Valid sequence lengths + targets: Target values where floor=stage, remainder=tau + scheme: "sparse" or "dense" + + Returns: + Dict with stage_loss, subtask_loss, total_loss + """ + num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages + + # Ground truth: stage (integer) and tau (fractional) + # Clamp stage indices to valid range [0, num_classes-1] to handle edge cases + # where targets may exceed expected range (e.g., frames between subtasks) + gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1) # (B, T) + gt_tau = torch.remainder(targets, 1.0) # (B, T) + + # Run stage model + stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme) + + # 75%/25% GT/predicted stage conditioning + if random.random() < self.gt_stage_ratio: + # Mode 1: Use ground truth stage -> one-hot + stage_emb = gen_stage_emb(num_classes, targets) # (B, 1, T, C) + else: + # Mode 2: Use predicted stage argmax -> one-hot + stage_idx = stage_pred.argmax(dim=-1) # (B, T) + stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float() # (B, T, C) + stage_emb = stage_onehot.unsqueeze(1) # (B, 1, T, C) + + # Run subtask model with stage prior + tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme) + + # Compute losses + stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean") + subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean") + + return { + "stage_loss": stage_loss, + "subtask_loss": subtask_loss, + "total_loss": stage_loss + subtask_loss, + } + + def forward(self, batch): + """ + Forward pass for SARM reward model training. + + Uses stage+tau target format where: + - Integer part = stage index + - Fractional part = within-stage progress (tau) + + Training uses 75%/25% GT/predicted stage conditioning. 
+ + Args: + batch: Dictionary with 'observation' containing: + - 'video_features': (B, T, 512) pre-encoded video features + - 'text_features': (B, 512) or (B, T, 512) text features + - 'state_features': (B, T, state_dim) joint state features + - 'lengths': (B,) valid sequence lengths + - 'sparse_targets': (B, T) sparse targets (stage.tau format) + - 'dense_targets': (B, T) dense targets (optional, for dual mode) + + Returns: + Tuple of (total_loss, output_dict with loss components) + """ + observation = batch.get("observation", batch) + + # Extract features + video_features = observation["video_features"].to(self.device) + text_features = observation["text_features"].to(self.device) + state_features = observation.get("state_features") + if state_features is not None: + state_features = state_features.to(self.device) + + batch_size = video_features.shape[0] + seq_len = video_features.shape[1] + + # Get lengths (default to full sequence) + lengths = observation.get("lengths") + if lengths is None: + lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device) + else: + lengths = lengths.to(self.device) + + # Reshape video to (B, N, T, D) - single camera + img_emb = video_features.unsqueeze(1) + + # Pad state to max_state_dim + if state_features is None: + state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device) + else: + state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim) + + output_dict = {} + total_loss = torch.tensor(0.0, device=self.device) + + # Sparse training (always) + sparse_targets = observation.get("sparse_targets") + if sparse_targets is None: + # Try legacy format + sparse_targets = observation.get("targets") + if sparse_targets is None: + raise ValueError("sparse_targets (or targets) is required for SARM training") + sparse_targets = sparse_targets.to(self.device) + + sparse_result = self._train_step( + img_emb, text_features, state_features, lengths, 
sparse_targets, scheme="sparse" + ) + output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item() + output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item() + total_loss = total_loss + sparse_result["total_loss"] + + # Dense training (if dual mode) + if self.config.uses_dual_heads: + dense_targets = observation.get("dense_targets") + if dense_targets is not None: + dense_targets = dense_targets.to(self.device) + dense_result = self._train_step( + img_emb, text_features, state_features, lengths, dense_targets, scheme="dense" + ) + output_dict["dense_stage_loss"] = dense_result["stage_loss"].item() + output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item() + total_loss = total_loss + dense_result["total_loss"] + + output_dict["total_loss"] = total_loss.item() + return total_loss, output_dict + + +def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor: + """Compute cross-entropy loss for stage classification.""" + _, _, num_stages = stage_logits.shape + stage_logits_flat = stage_logits.reshape(-1, num_stages) + # Clamp target stage indices to valid range [0, num_stages-1] + target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1) + return F.cross_entropy(stage_logits_flat, target_stages_flat) diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py new file mode 100644 index 000000000..5c617282a --- /dev/null +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -0,0 +1,518 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SARM Processor for encoding images/text and generating stage+tau targets.""" + +import random +from typing import Any + +import numpy as np +import pandas as pd +import torch +from faker import Faker +from PIL import Image +from transformers import CLIPModel, CLIPProcessor + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.sarm.configuration_sarm import SARMConfig +from lerobot.policies.sarm.sarm_utils import ( + apply_rewind_augmentation, + compute_absolute_indices, + find_stage_and_tau, + pad_state_to_max_dim, +) +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + RenameObservationsProcessorStep, +) +from lerobot.processor.converters import ( + from_tensor_to_numpy, + policy_action_to_transition, + transition_to_policy_action, +) +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.processor.pipeline import PipelineFeatureType +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +class SARMEncodingProcessorStep(ProcessorStep): + """ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM.""" + + def __init__( + self, + config: SARMConfig, + image_key: str | None = None, + dataset_meta=None, + dataset_stats: dict | None = None, + ): + super().__init__() + self.config = config + self.image_key = image_key or config.image_key + self.dataset_meta = dataset_meta + 
self.dataset_stats = dataset_stats + self.annotation_mode = config.annotation_mode + + # Helper to create temporal proportions dict + def make_props_dict(names, props): + return dict(zip(names, props, strict=True)) if names and props else None + + # Sparse annotations (always needed) + self.sparse_temporal_proportions = make_props_dict( + config.sparse_subtask_names, config.sparse_temporal_proportions + ) + self.sparse_subtask_names = config.sparse_subtask_names + + # Dense annotations (only for dual mode) + self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None + self.dense_temporal_proportions = ( + make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions) + if config.uses_dual_heads + else None + ) + + self.device = torch.device( + self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu" + ) + + self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) + self.clip_model.to(self.device) + self.clip_model.eval() + + self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"] + self.fake = Faker() + + def _find_episode_for_frame(self, frame_idx: int) -> int: + """Find the episode index for a given frame index.""" + for ep_idx in range(len(self.dataset_meta.episodes)): + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + if ep_start <= frame_idx < ep_end: + return ep_idx + return 0 + + def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray: + """Get episode indices for each frame index.""" + if episode_index is None: + return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) + + episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index))) + + # If single episode but multiple 
frames, compute episode for each frame + if len(episode_indices) == 1 and len(frame_indices) > 1: + return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices]) + + return episode_indices + + def _generate_perturbed_task(self) -> str: + """Generate a random perturbed task string for language perturbation.""" + num_words = random.randint(1, 5) + verb = random.choice(self.verbs) + phrase = " ".join([verb] + self.fake.words(nb=num_words)) + return phrase + + def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]: + """Get global subtask names and temporal proportions for an annotation type.""" + if annotation_type == "dense": + return self.dense_subtask_names, self.dense_temporal_proportions + return self.sparse_subtask_names, self.sparse_temporal_proportions + + def _load_episode_annotations( + self, + ep_idx: int, + episodes_df: pd.DataFrame | None, + annotation_type: str, + global_names: list[str], + ) -> tuple[list | None, list | None, list | None]: + """Load subtask annotations for an episode from DataFrame.""" + # Single-stage mode: (linear progress 0→1) + if episodes_df is None or len(global_names) == 1: + return None, None, None + + # Resolve column name with fallback + def col(suffix): + prefixed = f"{annotation_type}_{suffix}" + return prefixed if prefixed in episodes_df.columns else suffix + + col_names = col("subtask_names") + if col_names not in episodes_df.columns or ep_idx >= len(episodes_df): + return None, None, None + + subtask_names = episodes_df.loc[ep_idx, col_names] + if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): + return None, None, None + + return ( + subtask_names, + episodes_df.loc[ep_idx, col("subtask_start_frames")], + episodes_df.loc[ep_idx, col("subtask_end_frames")], + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Encode images, text, and normalize states in the transition. 
+ + Implements SARM training data preparation: + - Applies language perturbation (20% probability) + - Applies rewind augmentation (80% probability) + - Generates stage+tau targets for all frames + - Outputs lengths tensor for valid sequence masking + """ + new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition) + observation = new_transition.get(TransitionKey.OBSERVATION) + comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + frame_index = comp_data.get("index") + episode_index = comp_data.get("episode_index") + + if frame_index is None: + raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA") + if episode_index is None: + raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA") + + frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index))) + episode_indices = self._get_episode_indices(frame_indices, episode_index) + + image = observation.get(self.image_key) + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + + # If 4D (T, C, H, W) from delta_timestamps, add batch dim + # If 3D (C, H, W) single frame, add batch and time dims + if image.ndim == 4: + image = image[np.newaxis, ...] # (T, C, H, W) -> (1, T, C, H, W) + elif image.ndim == 3: + image = image[np.newaxis, np.newaxis, ...] 
# (C, H, W) -> (1, 1, C, H, W) + + batch_size = image.shape[0] + total_frames = image.shape[1] # Should be 13: 9 obs + 4 rewind placeholders + n_obs_steps = self.config.n_obs_steps + max_rewind_steps = self.config.max_rewind_steps + n_obs_frames = 1 + n_obs_steps # 9 observation frames (including current) + + # Rewind augmentation + rewind_steps = torch.zeros(batch_size, dtype=torch.int32) + apply_rewind = self.training and random.random() < self.config.rewind_probability + + if apply_rewind and self.dataset_meta is not None: + for b_idx, (ep_idx, frame_idx) in enumerate( + zip(episode_indices.tolist(), frame_indices.tolist(), strict=True) + ): + ep_idx, frame_idx = int(ep_idx), int(frame_idx) + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + + rewind_step, _ = apply_rewind_augmentation( + frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap + ) + rewind_steps[b_idx] = rewind_step + + # Compute valid lengths: n_obs_frames + rewind_steps + lengths = n_obs_frames + rewind_steps # (B,) + + # Apply rewind masking to images + # For frames beyond valid length, we mask with zeros (or copy last valid frame) + for b_idx in range(batch_size): + valid_len = lengths[b_idx].item() + if valid_len < total_frames: + image[b_idx, valid_len:] = 0 # Zero out frames beyond valid length + + # Encode images with CLIP + video_features = self._encode_images_batch(image) + observation["video_features"] = video_features + + state_key = self.config.state_key + state_data = observation.get(state_key) + + if isinstance(state_data, torch.Tensor): + state_tensor = state_data.float() + else: + state_tensor = torch.tensor(state_data, dtype=torch.float32) + + if state_tensor.ndim == 2: + state_tensor = state_tensor.unsqueeze(0) # (T, D) -> (1, T, D) + elif state_tensor.ndim == 1: + state_tensor = state_tensor.unsqueeze(0).unsqueeze(0) # (D,) -> (1, 1, D) + + # Apply same rewind masking to state + for b_idx in range(batch_size): + valid_len = 
lengths[b_idx].item() + if valid_len < state_tensor.shape[1]: + state_tensor[b_idx, valid_len:] = 0 # Zero out frames beyond valid length + + observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim) + + task = comp_data.get("task") + if isinstance(task, list): + task = task[0] if task else "" + + # Apply language perturbation during training (20% probability) + # When perturbed, targets will be zeroed to train model to output low values for irrelevant text + apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability + if apply_perturbation: + task = self._generate_perturbed_task() + + # Encode text with CLIP + observation["text_features"] = self._encode_text_clip(task, batch_size) + + # Store lengths for model + observation["lengths"] = lengths + + # When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss + if self.dataset_meta is not None: + episodes_df = None + if self.sparse_subtask_names != ["task"]: + episodes_df = self.dataset_meta.episodes.to_pandas() + + # Generate sparse targets + if self.sparse_temporal_proportions is not None: + if apply_perturbation: + # Zero targets when language is perturbed + sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32) + else: + sparse_targets = self._compute_batch_targets( + frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse" + ) + observation["sparse_targets"] = sparse_targets + + # Generate dense targets (for dual mode) + if self.config.uses_dual_heads and self.dense_temporal_proportions is not None: + if apply_perturbation: + # Zero targets when language is perturbed + dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32) + else: + dense_targets = self._compute_batch_targets( + frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense" + ) + observation["dense_targets"] = dense_targets + + 
new_transition[TransitionKey.OBSERVATION] = observation + return new_transition + + def _compute_batch_targets( + self, + frame_indices: np.ndarray, + episode_indices: np.ndarray, + lengths: torch.Tensor, + rewind_steps: torch.Tensor, + episodes_df: pd.DataFrame | None, + annotation_type: str, + ) -> torch.Tensor: + """Compute stage+tau targets for a batch of samples.""" + batch_size = len(frame_indices) + n_obs_steps = self.config.n_obs_steps + max_rewind_steps = self.config.max_rewind_steps + total_frames = 1 + n_obs_steps + max_rewind_steps + frame_gap = self.config.frame_gap + + global_names, temporal_props = self._get_annotation_config(annotation_type) + targets = torch.zeros(batch_size, total_frames, dtype=torch.float32) + + for b_idx in range(batch_size): + ep_idx = int(episode_indices[b_idx]) + frame_idx = int(frame_indices[b_idx]) + + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + ep_length = ep_end - ep_start + + subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations( + ep_idx, episodes_df, annotation_type, global_names + ) + + # Compute observation frame indices + obs_indices, _ = compute_absolute_indices( + frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap + ) + obs_indices = obs_indices.tolist() + + # Compute targets for observation frames + for t_idx, abs_idx in enumerate(obs_indices): + rel_frame = abs_idx - ep_start + targets[b_idx, t_idx] = find_stage_and_tau( + rel_frame, + ep_length, + subtask_names, + subtask_start_frames, + subtask_end_frames, + global_names, + temporal_props, + return_combined=True, + ) + + # Compute targets for rewind frames (if any) + rewind_step = rewind_steps[b_idx].item() + if rewind_step > 0: + _, rewind_indices = apply_rewind_augmentation( + frame_idx, + ep_start, + n_obs_steps, + max_rewind_steps, + frame_gap=frame_gap, + rewind_step=rewind_step, + ) + + for r_idx, abs_idx in 
enumerate(rewind_indices[:rewind_step]): + rel_frame = max(0, abs_idx - ep_start) + targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau( + rel_frame, + ep_length, + subtask_names, + subtask_start_frames, + subtask_end_frames, + global_names, + temporal_props, + return_combined=True, + ) + + return targets + + @property + def training(self) -> bool: + return getattr(self, "_training_mode", True) + + def train(self, mode: bool = True): + """Set training mode for augmentation decisions.""" + self._training_mode = mode + return self + + def eval(self): + """Set evaluation mode (disable augmentations).""" + return self.train(False) + + @torch.no_grad() + def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor: + """Encode a batch of images using CLIP. + + Args: + images: Batched images with shape: (B, T, C, H, W) + + Returns: + Encoded feature vectors with shape (B, T, 512) + """ + + batch_size, seq_length = images.shape[0], images.shape[1] + images = images.reshape(batch_size * seq_length, *images.shape[2:]) + + num_frames = images.shape[0] + images_list = [] + for i in range(num_frames): + img = images[i] + if img.shape[0] in [1, 3]: # Channel first (C, H, W) + img = img.transpose(1, 2, 0) + + # Handle single channel + if img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + + if img.dtype != np.uint8: + img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8) + + images_list.append(Image.fromarray(img)) + + all_embeddings = [] + for i in range(0, num_frames, self.config.clip_batch_size): + batch_imgs = images_list[i : i + self.config.clip_batch_size] + + inputs = self.clip_processor(images=batch_imgs, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get image embeddings + embeddings = self.clip_model.get_image_features(**inputs).detach().cpu() + + # Handle single frame case + if embeddings.dim() == 1: + embeddings = embeddings.unsqueeze(0) + + all_embeddings.append(embeddings) + + 
all_embeddings = torch.cat(all_embeddings) # (B*T, 512) + all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512) + + return all_embeddings + + @torch.no_grad() + def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor: + """Encode text using CLIP text encoder (per SARM paper A.4). + + Args: + text: Task description text to encode + batch_size: Batch size to replicate for + + Returns: + Encoded text features with shape (B, 512) + """ + inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu() + text_embedding = text_embedding.expand(batch_size, -1) + + return text_embedding + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Add encoded features to the observation features.""" + features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim) + ) + features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.config.text_dim,) + ) + features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature( + type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim) + ) + return features + + +def make_sarm_pre_post_processors( + config: SARMConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + dataset_meta=None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Create pre-processor and post-processor pipelines for SARM.""" + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=[ + AddBatchDimensionProcessorStep(), + 
RenameObservationsProcessorStep(rename_map={}), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + SARMEncodingProcessorStep( + config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats + ), + DeviceProcessorStep(device=config.device), + ], + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=[DeviceProcessorStep(device="cpu")], + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/sarm/sarm_utils.py b/src/lerobot/policies/sarm/sarm_utils.py new file mode 100644 index 000000000..5b6955d38 --- /dev/null +++ b/src/lerobot/policies/sarm/sarm_utils.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 + + +def find_stage_and_tau( + current_frame: int, + episode_length: int, + subtask_names: list | None, + subtask_start_frames: list | None, + subtask_end_frames: list | None, + global_subtask_names: list, + temporal_proportions: dict, + return_combined: bool = False, +) -> tuple[int, float] | float: + """Find stage and within-stage progress (tau) for a frame. 
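+
+    Worked example (hypothetical annotation): with ``global_subtask_names =
+    ["grasp", "place"]`` and the current frame at relative index 75, inside a
+    "place" subtask annotated over frames 50..100: ``stage_idx = 1`` and
+    ``tau = (75 - 50) / (100 - 50) = 0.5``, so ``return_combined=True``
+    yields a target of 1.5.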
+ + Args: + current_frame: Frame index relative to episode start + episode_length: Total frames in episode + subtask_names: Subtask names for this episode (None for single_stage) + subtask_start_frames: Subtask start frames + subtask_end_frames: Subtask end frames + global_subtask_names: Global list of all subtask names + temporal_proportions: Dict of temporal proportions + return_combined: If True, return stage+tau as float; else (stage_idx, tau) tuple + + Returns: + Float (stage.tau) if return_combined, else (stage_idx, tau) tuple + """ + stage_idx, tau = 0, 0.0 + num_stages = len(global_subtask_names) + + # Single-stage mode: linear progress from 0 to 1 + if num_stages == 1: + tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1))) + elif subtask_names is None: + pass # stage_idx=0, tau=0.0 + elif current_frame < subtask_start_frames[0]: + pass # Before first subtask: stage_idx=0, tau=0.0 + elif current_frame > subtask_end_frames[-1]: + stage_idx, tau = num_stages - 1, 0.999 # After last subtask + else: + # Find which subtask this frame belongs to + found = False + for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True): + if start <= current_frame <= end: + stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0 + tau = compute_tau(current_frame, start, end) + found = True + break + # Frame between subtasks - use previous subtask's end state + if not found: + for j in range(len(subtask_names) - 1): + if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]: + name = subtask_names[j] + stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j + tau = 1.0 + break + + if return_combined: + # Clamp to avoid overflow at end + if stage_idx >= num_stages - 1 and tau >= 1.0: + return num_stages - 1 + 0.999 + return stage_idx + tau + return stage_idx, tau + + +def compute_absolute_indices( + frame_idx: int, + ep_start: int, + ep_end: int, + 
n_obs_steps: int, + frame_gap: int = 30, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute absolute frame indices with clamping for bidirectional observation sequence. + + Bidirectional sampling centered on target frame: + - Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames) + - Current: [0] (1 frame) + - After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames) + - Total: n_obs_steps + 1 frames + + Out-of-bounds frames are clamped (duplicated from boundary). + + Args: + frame_idx: Target frame index (center frame of sequence) + ep_start: Episode start index + ep_end: Episode end index (exclusive) + n_obs_steps: Number of observation steps (must be even for symmetric sampling) + frame_gap: Gap between observation frames + + Returns: + Tuple of (indices, out_of_bounds_flags) + """ + half_steps = n_obs_steps // 2 + + # Bidirectional deltas: past + current + future + past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)] + future_deltas = [frame_gap * i for i in range(1, half_steps + 1)] + delta_indices = past_deltas + [0] + future_deltas + + frames = [] + out_of_bounds = [] + + for delta in delta_indices: + target_idx = frame_idx + delta + # Clamp to episode bounds (duplicate boundary frames for out-of-bounds) + clamped_idx = max(ep_start, min(ep_end - 1, target_idx)) + frames.append(clamped_idx) + # Flag as out-of-bounds if clamping occurred + out_of_bounds.append(1 if target_idx != clamped_idx else 0) + + return torch.tensor(frames), torch.tensor(out_of_bounds) + + +def apply_rewind_augmentation( + frame_idx: int, + ep_start: int, + n_obs_steps: int, + max_rewind_steps: int, + frame_gap: int = 30, + rewind_step: int | None = None, +) -> tuple[int, list[int]]: + """ + Generate rewind frame indices for temporal augmentation. + + Rewind simulates going backwards through previously seen frames, + starting from before the earliest observation frame (for bidirectional sampling). 
+    Appends reversed frames after the observation sequence.
+
+    Args:
+        frame_idx: Target frame index (center of bidirectional observation window)
+        ep_start: Episode start index
+        n_obs_steps: Number of observation steps
+        max_rewind_steps: Maximum rewind steps
+        frame_gap: Gap between frames
+        rewind_step: If provided, use this exact rewind step (for deterministic behavior).
+            If None, sample randomly.
+
+    Returns:
+        Tuple of (rewind_step, rewind_indices)
+    """
+    # For bidirectional sampling, earliest obs frame is at frame_idx - half_steps * frame_gap
+    half_steps = n_obs_steps // 2
+    earliest_obs_frame = frame_idx - half_steps * frame_gap
+
+    # Required history: frames before earliest observation frame
+    if earliest_obs_frame <= ep_start:
+        return 0, []  # No history before observation window
+
+    # Max valid rewind steps based on available history before earliest obs frame
+    available_history = earliest_obs_frame - ep_start
+    max_valid_step = available_history // frame_gap
+    max_rewind = min(max_rewind_steps, max(0, max_valid_step))
+
+    if max_rewind <= 0:
+        return 0, []
+
+    # Sample rewind steps if not provided
+    rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind)
+
+    if rewind_step == 0:
+        return 0, []
+
+    # Generate rewind indices going backwards from earliest obs frame
+    # rewind_indices[0] is closest to obs window, rewind_indices[-1] is furthest back
+    rewind_indices = []
+    for i in range(1, rewind_step + 1):
+        idx = earliest_obs_frame - i * frame_gap
+        idx = max(ep_start, idx)  # Clamp to episode start
+        rewind_indices.append(idx)
+
+    return rewind_step, rewind_indices
+
+
+def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float:
+    """Compute τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1].
+    Returns 1.0 for zero-duration subtasks."""
+    duration = subtask_end - subtask_start
+    if duration <= 0:
+        return 1.0
+    return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))
+
+
+def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
+    """Pad the state tensor's last dimension to max_state_dim with zeros."""
+    current_dim = state.shape[-1]
+    if current_dim >= max_state_dim:
+        return state[..., :max_state_dim]  # Truncate if larger
+
+    # Pad with zeros on the right
+    padding = (0, max_state_dim - current_dim)  # (left, right) for last dim
+    return F.pad(state, padding, mode="constant", value=0)
+
+
+def temporal_proportions_to_breakpoints(
+    temporal_proportions: dict[str, float] | list[float] | None,
+    subtask_names: list[str] | None = None,
+) -> list[float] | None:
+    """Convert temporal proportions to cumulative breakpoints for normalization."""
+    if temporal_proportions is None:
+        return None
+
+    if isinstance(temporal_proportions, dict):
+        if subtask_names is not None:
+            proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names]
+        else:
+            proportions = list(temporal_proportions.values())
+    else:
+        proportions = list(temporal_proportions)
+
+    total = sum(proportions)
+    if total > 0 and abs(total - 1.0) > 1e-6:
+        proportions = [p / total for p in proportions]
+
+    breakpoints = [0.0]
+    cumsum = 0.0
+    for prop in proportions:
+        cumsum += prop
+        breakpoints.append(cumsum)
+    breakpoints[-1] = 1.0
+
+    return breakpoints
+
+
+def normalize_stage_tau(
+    x: float | torch.Tensor,
+    num_stages: int | None = None,
+    breakpoints: list[float] | None = None,
+    temporal_proportions: dict[str, float] | list[float] | None = None,
+    subtask_names: list[str] | None = None,
+) -> float | torch.Tensor:
+    """
+    Normalize stage+tau reward to [0, 1] with custom breakpoints.
+
+    Maps stage index + within-stage tau to normalized progress [0, 1].
+    The breakpoints are designed to give appropriate weight to each stage
+    based on their importance in the task (using temporal proportions).
+
+    Priority: breakpoints > temporal_proportions > linear fallback
+
+    Args:
+        x: Raw reward value (stage index + tau) where stage ∈ [0, num_stages-1] and tau ∈ [0, 1)
+        num_stages: Number of stages (required if breakpoints/proportions not provided)
+        breakpoints: Optional custom breakpoints list of length num_stages + 1.
+        temporal_proportions: Optional temporal proportions dict/list to compute breakpoints.
+        subtask_names: Optional ordered list of subtask names (for dict proportions)
+
+    Returns:
+        Normalized progress value ∈ [0, 1]
+    """
+    if breakpoints is not None:
+        num_stages = len(breakpoints) - 1
+    elif temporal_proportions is not None:
+        breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
+        num_stages = len(breakpoints) - 1
+    elif num_stages is not None:
+        breakpoints = [i / num_stages for i in range(num_stages + 1)]
+    else:
+        raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided")
+
+    if isinstance(x, torch.Tensor):
+        result = torch.zeros_like(x)
+        for i in range(num_stages):
+            mask = (x >= i) & (x < i + 1)
+            tau_in_stage = x - i
+            result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i])
+        result[x >= num_stages] = 1.0
+        return result.clamp(0.0, 1.0)
+    else:
+        if x < 0:
+            return 0.0
+        if x >= num_stages:
+            return 1.0
+        stage = int(x)
+        tau = x - stage
+        return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])
diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py
index 485d3e4e5..f998661f9 100644
--- a/src/lerobot/policies/smolvla/modeling_smolvla.py
+++ b/src/lerobot/policies/smolvla/modeling_smolvla.py
@@ -231,6 +231,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: SmolVLAConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -352,8 +353,19 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def _rtc_enabled(self) -> bool:
         return self.config.rtc_config is not None and self.config.rtc_config.enabled
 
-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
-        """Do a full training forward pass to compute the loss"""
+    def forward(
+        self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean"
+    ) -> dict[str, Tensor]:
+        """Do a full training forward pass to compute the loss.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            noise: Optional noise tensor for flow matching.
+            time: Optional time tensor for flow matching.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
         if self.config.adapt_to_pi_aloha:
             batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
             batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
@@ -377,11 +389,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
             losses = losses[:, :, : self.config.max_action_dim]
             loss_dict["losses_after_rm_padding"] = losses.clone()
 
-        # For backward pass
-        loss = losses.mean()
-        # For backward pass
-        loss_dict["loss"] = loss.item()
-        return loss, loss_dict
+        if reduction == "none":
+            # Return per-sample losses (B,) by averaging over time and action dims
+            per_sample_loss = losses.mean(dim=(1, 2))
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = losses.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict
 
     def prepare_images(self, batch):
         """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
index 195cf6154..f83c82e21 100644
--- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py
+++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
@@ -65,6 +65,7 @@ class TDMPCPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: TDMPCConfig,
+        **kwargs,
     ):
         """
         Args:
diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py
index c4ca35b72..bfbe2bf1d 100644
--- a/src/lerobot/policies/utils.py
+++ b/src/lerobot/policies/utils.py
@@ -231,11 +231,20 @@ def validate_visual_features_consistency(
     """
     Validates visual feature consistency between a policy config and provided dataset/environment features.
 
+    Validation passes if EITHER:
+    - Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
+    - Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
+
     Args:
         cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
         features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
     """
     expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
     provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
-    if not provided_visuals.issubset(expected_visuals):
+
+    # Accept if either direction is a subset
+    policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
+    dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
+
+    if not (policy_subset_of_dataset or dataset_subset_of_policy):
         raise_feature_mismatch_error(provided_visuals, expected_visuals)
diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py
index 91d609701..359b4fdb1 100644
--- a/src/lerobot/policies/vqbet/modeling_vqbet.py
+++ b/src/lerobot/policies/vqbet/modeling_vqbet.py
@@ -47,6 +47,7 @@ class VQBeTPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: VQBeTConfig | None = None,
+        **kwargs,
     ):
         """
         Args:
diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py
index 27c7c6e1b..0436ae527 100644
--- a/src/lerobot/policies/xvla/modeling_xvla.py
+++ b/src/lerobot/policies/xvla/modeling_xvla.py
@@ -273,7 +273,7 @@ class XVLAPolicy(PreTrainedPolicy):
     config_class = XVLAConfig
     name = "xvla"
 
-    def __init__(self, config: XVLAConfig):
+    def __init__(self, config: XVLAConfig, **kwargs):
         super().__init__(config)
         config.validate_features()
         florence_config = config.get_florence_config()
diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py
index 6b0b67598..126be0e36 100644
--- a/src/lerobot/processor/converters.py
+++ b/src/lerobot/processor/converters.py
@@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
     task_key = {"task": batch["task"]} if "task" in batch else {}
     index_key = {"index": batch["index"]} if "index" in batch else {}
     task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
+    episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
 
-    return {**pad_keys, **task_key, **index_key, **task_index_key}
+    return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
 
 
 def create_transition(
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 1ebdee600..6cf733442 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -62,6 +62,7 @@ def update_policy(
     accelerator: Accelerator,
     lr_scheduler=None,
     lock=None,
+    rabc_weights_provider=None,
 ) -> tuple[MetricsTracker, dict]:
     """
     Performs a single training step to update the policy's weights.
@@ -78,6 +79,7 @@
         accelerator: The Accelerator instance for distributed training and mixed precision.
         lr_scheduler: An optional learning rate scheduler.
         lock: An optional lock for thread-safe optimizer updates.
+        rabc_weights_provider: Optional RABCWeights instance for sample weighting.
 
     Returns:
         A tuple containing:
@@ -87,9 +89,30 @@
     start_time = time.perf_counter()
     policy.train()
 
+    # Get RA-BC weights if enabled
+    rabc_batch_weights = None
+    rabc_batch_stats = None
+    if rabc_weights_provider is not None:
+        rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
+
     # Let accelerator handle mixed precision
     with accelerator.autocast():
-        loss, output_dict = policy.forward(batch)
+        # Use per-sample loss when RA-BC is enabled for proper weighting
+        if rabc_batch_weights is not None:
+            # Get per-sample losses
+            per_sample_loss, output_dict = policy.forward(batch, reduction="none")
+
+            # Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
+            # rabc_batch_weights is already normalized to sum to batch_size
+            epsilon = 1e-6
+            loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
+            # Log raw mean weight (before normalization) - this is the meaningful metric
+            output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
+            output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
+            output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
+        else:
+            loss, output_dict = policy.forward(batch)
+
     # TODO(rcadene): policy.unnormalize_outputs(out_dict)
 
     # Use accelerator's backward method
@@ -141,8 +164,6 @@
         cfg: A `TrainPipelineConfig` object containing all training configurations.
         accelerator: Optional Accelerator instance. If None, one will be created automatically.
     """
-    cfg.validate()
-
     # Create Accelerator if not provided
     # It will automatically detect if running in distributed mode or single-process mode
     # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -159,6 +180,8 @@
     # When using accelerate, only the main process should log to avoid duplicate outputs
     is_main_process = accelerator.is_main_process
 
+    cfg.validate()
+
     # Only log on main process
     if is_main_process:
         logging.info(pformat(cfg.to_dict()))
@@ -217,6 +240,10 @@
         # Only provide dataset_stats when not resuming from saved processor state
         processor_kwargs["dataset_stats"] = dataset.meta.stats
 
+    # For SARM, always provide dataset_meta for progress normalization
+    if cfg.policy.type == "sarm":
+        processor_kwargs["dataset_meta"] = dataset.meta
+
    if cfg.policy.pretrained_path is not None:
        processor_kwargs["preprocessor_overrides"] = {
            "device_processor": {"device": device.type},
@@ -248,6 +275,29 @@
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
 
+    # Load precomputed SARM progress for RA-BC if enabled
+    # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
+    rabc_weights = None
+    if cfg.use_rabc:
+        from lerobot.utils.rabc import RABCWeights
+
+        # Get chunk_size from policy config
+        chunk_size = getattr(policy.config, "chunk_size", None)
+        if chunk_size is None:
+            raise ValueError("chunk_size not found in policy config")
+
+        head_mode = getattr(cfg, "rabc_head_mode", "sparse")
+        logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
+        logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
+        rabc_weights = RABCWeights(
+            progress_path=cfg.rabc_progress_path,
+            chunk_size=chunk_size,
+            head_mode=head_mode,
+            kappa=getattr(cfg, "rabc_kappa", 0.01),
+            epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
+            device=device,
+        )
+
     step = 0  # number of policy updates (forward + backward + optim)
 
     if cfg.resume:
@@ -327,7 +377,9 @@
     )
 
     if is_main_process:
-        logging.info("Start offline training on a fixed dataset")
+        logging.info(
+            f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
+        )
 
     for _ in range(step, cfg.steps):
         start_time = time.perf_counter()
@@ -343,6 +395,7 @@
             cfg.optimizer.grad_clip_norm,
             accelerator=accelerator,
             lr_scheduler=lr_scheduler,
+            rabc_weights_provider=rabc_weights,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -359,6 +412,16 @@
                 wandb_log_dict = train_tracker.to_dict()
                 if output_dict:
                     wandb_log_dict.update(output_dict)
+                # Log RA-BC statistics if enabled
+                if rabc_weights is not None:
+                    rabc_stats = rabc_weights.get_stats()
+                    wandb_log_dict.update(
+                        {
+                            "rabc_delta_mean": rabc_stats["delta_mean"],
+                            "rabc_delta_std": rabc_stats["delta_std"],
+                            "rabc_num_frames": rabc_stats["num_frames"],
+                        }
+                    )
                 wandb_logger.log_dict(wandb_log_dict, step)
                 train_tracker.reset_averages()
diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/utils/rabc.py
new file mode 100644
index 000000000..c529f3ccc
--- /dev/null
+++ b/src/lerobot/utils/rabc.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+
+
+class RABCWeights:
+    """
+    Load precomputed SARM progress values and compute RA-BC weights during training.
+
+    Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
+    During training, computes:
+    - progress_delta = progress[t + chunk_size] - progress[t]
+    - rabc_weight based on the delta (paper Eq. 8-9)
+
+    Args:
+        progress_path: Path to parquet file with precomputed progress values
+        chunk_size: Number of frames ahead for computing progress delta
+        head_mode: Which SARM head to use ("sparse" or "dense")
+        kappa: Hard threshold for high-quality samples (default: 0.01)
+        epsilon: Small constant for numerical stability (default: 1e-6)
+        fallback_weight: Weight to use for frames without valid delta (default: 1.0)
+        device: Device to return tensors on
+    """
+
+    def __init__(
+        self,
+        progress_path: str | Path,
+        chunk_size: int = 50,
+        head_mode: str = "sparse",
+        kappa: float = 0.01,
+        epsilon: float = 1e-6,
+        fallback_weight: float = 1.0,
+        device: torch.device = None,
+    ):
+        self.progress_path = Path(progress_path)
+        self.chunk_size = chunk_size
+        self.head_mode = head_mode
+        self.kappa = kappa
+        self.epsilon = epsilon
+        self.fallback_weight = fallback_weight
+        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Determine progress column name
+        self.progress_column = f"progress_{head_mode}"
+
+        # Load progress values
+        logging.info(f"Loading SARM progress values from {self.progress_path}")
+        self.df = pd.read_parquet(self.progress_path)
+
+        # Check if the requested head mode column exists
+        if self.progress_column not in self.df.columns:
+            available = [c for c in self.df.columns if c.startswith("progress")]
+            raise ValueError(
+                f"Column '{self.progress_column}' not found. Available progress columns: {available}"
+            )
+
+        logging.info(f"Using progress column: {self.progress_column}")
+
+        self.progress_lookup = {}
+        self.episode_lookup = {}
+
+        for _, row in self.df.iterrows():
+            global_idx = int(row["index"])
+            progress = row[self.progress_column]
+            episode_idx = int(row["episode_index"])
+
+            if not np.isnan(progress):
+                self.progress_lookup[global_idx] = float(progress)
+                self.episode_lookup[global_idx] = episode_idx
+
+        # Build episode boundaries for delta computation
+        self.episode_boundaries = {}
+        for episode_idx in self.df["episode_index"].unique():
+            ep_df = self.df[self.df["episode_index"] == episode_idx]
+            self.episode_boundaries[int(episode_idx)] = {
+                "start": int(ep_df["index"].min()),
+                "end": int(ep_df["index"].max()) + 1,
+            }
+
+        logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
+        logging.info(f"Chunk size for delta computation: {chunk_size}")
+
+        # Compute global statistics for weight computation
+        self._compute_global_stats()
+
+    def _compute_global_stats(self):
+        """Compute global mean and std of progress deltas for weight calculation."""
+        all_deltas = []
+
+        for global_idx, progress in self.progress_lookup.items():
+            episode_idx = self.episode_lookup.get(global_idx)
+            if episode_idx is None:
+                continue
+
+            bounds = self.episode_boundaries.get(episode_idx)
+            if bounds is None:
+                continue
+
+            future_idx = global_idx + self.chunk_size
+            if future_idx >= bounds["end"]:
+                # Near end of episode: use last frame's progress
+                future_idx = bounds["end"] - 1
+
+            future_progress = self.progress_lookup.get(future_idx)
+            if future_progress is not None:
+                delta = future_progress - progress
+                all_deltas.append(delta)
+
+        if all_deltas:
+            self.delta_mean = max(np.mean(all_deltas), 0.0)
+            self.delta_std = max(np.std(all_deltas), self.epsilon)
+            logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
+        else:
+            self.delta_mean = 0.0
+            self.delta_std = self.epsilon
+            logging.warning("No valid progress deltas found, using default stats")
+
+    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
+        """
+        Compute RA-BC weights for a batch.
+
+        For each sample:
+        1. Get progress at current frame
+        2. Get progress at frame + chunk_size (within same episode)
+        3. Compute delta = future_progress - current_progress
+        4. Compute weight using paper Eq. 8-9
+
+        Args:
+            batch: Training batch containing "index" key with global frame indices
+
+        Returns:
+            Tuple of:
+            - Weights tensor (batch_size,) normalized to sum to batch_size
+            - Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
+        """
+        indices = batch.get("index")
+        if indices is None:
+            logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
+            batch_size = self._get_batch_size(batch)
+            # Return the full stats dict so downstream logging never hits a missing key
+            return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": 0}
+
+        # Convert to list of ints
+        if isinstance(indices, torch.Tensor):
+            indices = indices.cpu().numpy().tolist()
+        elif isinstance(indices, np.ndarray):
+            indices = indices.tolist()
+
+        # Compute deltas and weights for each sample
+        deltas = []
+        for idx in indices:
+            idx = int(idx)
+            delta = self._compute_delta(idx)
+            deltas.append(delta)
+
+        deltas = np.array(deltas, dtype=np.float32)
+
+        # Compute weights from deltas
+        weights = self._compute_weights(deltas)
+
+        # Compute stats before normalization for logging
+        raw_mean_weight = float(np.nanmean(weights))
+        num_zero_weight = int(np.sum(weights == 0))
+        num_full_weight = int(np.sum(weights == 1.0))
+        batch_stats = {
+            "raw_mean_weight": raw_mean_weight,
+            "num_zero_weight": num_zero_weight,
+            "num_full_weight": num_full_weight,
+        }
+
+        weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
+
+        # Normalize to sum to batch_size
+        batch_size = len(weights)
+        weight_sum = weights.sum() + self.epsilon
+        weights = weights * batch_size / weight_sum
+
+        return weights, batch_stats
+
+    def _compute_delta(self, global_idx: int) -> float:
+        """Compute progress delta for a single frame."""
+        current_progress = self.progress_lookup.get(global_idx)
+        if current_progress is None:
+            return np.nan
+
+        episode_idx = self.episode_lookup.get(global_idx)
+        if episode_idx is None:
+            return np.nan
+
+        bounds = self.episode_boundaries.get(episode_idx)
+        if bounds is None:
+            return np.nan
+
+        future_idx = global_idx + self.chunk_size  # Δ = chunk_size
+        if future_idx >= bounds["end"]:
+            # Near end of episode: use last frame's progress instead
+            future_idx = bounds["end"] - 1
+
+        future_progress = self.progress_lookup.get(future_idx)
+        if future_progress is None:
+            return np.nan
+
+        return future_progress - current_progress
+
+    def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
+        """
+        Compute RA-BC weights from progress deltas.
+
+        Following paper Eq. 8-9:
+        - Soft weight: w̃_i = clip((r_i - (μ - 2σ)) / (4σ + ε), 0, 1)
+        - Final weight: w_i = 1{r_i > κ} + 1{0 ≤ r_i ≤ κ} * w̃_i
+
+        Returns:
+            Array of weights
+        """
+        valid_mask = ~np.isnan(deltas)
+
+        # Compute soft weights using global statistics
+        lower_bound = self.delta_mean - 2 * self.delta_std
+        soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
+        soft_weights = np.clip(soft_weights, 0.0, 1.0)
+
+        # Apply paper's Eq. 9
+        weights = np.zeros_like(deltas, dtype=np.float32)
+
+        # High quality: r_i > kappa -> weight = 1
+        high_quality_mask = deltas > self.kappa
+        weights[high_quality_mask] = 1.0
+
+        # Moderate quality: 0 <= r_i <= kappa -> weight = soft_weight
+        moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
+        weights[moderate_mask] = soft_weights[moderate_mask]
+
+        # Negative progress: r_i < 0 -> weight = 0 (already 0)
+        # Invalid (NaN): use fallback weight
+        weights[~valid_mask] = self.fallback_weight
+
+        return weights
+
+    def _get_batch_size(self, batch: dict) -> int:
+        """Determine batch size from batch."""
+        for key in ["action", "index"]:
+            if key in batch:
+                val = batch[key]
+                if isinstance(val, (torch.Tensor, np.ndarray)):
+                    return val.shape[0]
+        return 1
+
+    def get_stats(self) -> dict:
+        """Get statistics."""
+        return {
+            "num_frames": len(self.progress_lookup),
+            "chunk_size": self.chunk_size,
+            "head_mode": self.head_mode,
+            "delta_mean": self.delta_mean,
+            "delta_std": self.delta_std,
+            "kappa": self.kappa,
+        }
diff --git a/tests/policies/test_sarm_processor.py b/tests/policies/test_sarm_processor.py
new file mode 100644
index 000000000..66404f663
--- /dev/null
+++ b/tests/policies/test_sarm_processor.py
@@ -0,0 +1,694 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+pytest.importorskip("faker")
+
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+
+from lerobot.processor.core import TransitionKey
+
+
+class MockDatasetMeta:
+    """Mock dataset metadata for testing processor."""
+
+    def __init__(self, episodes: list[dict]):
+        self._episodes = episodes
+
+    @property
+    def episodes(self):
+        """Return episodes as a mock object with to_pandas() method."""
+        mock = MagicMock()
+        mock.__len__ = lambda s: len(self._episodes)
+        mock.__getitem__ = lambda s, idx: self._episodes[idx]
+        mock.to_pandas = lambda: pd.DataFrame(self._episodes)
+        return mock
+
+
+class MockConfig:
+    """Mock SARMConfig for testing processor methods."""
+
+    def __init__(
+        self,
+        n_obs_steps: int = 8,
+        max_rewind_steps: int = 4,
+        frame_gap: int = 30,
+        sparse_subtask_names: list = None,
+        sparse_temporal_proportions: list = None,
+        dense_subtask_names: list = None,
+        dense_temporal_proportions: list = None,
+        image_key: str = "observation.images.top",
+        state_key: str = "observation.state",
+        max_state_dim: int = 32,
+        device: str = None,
+        rewind_probability: float = 0.8,
+        language_perturbation_probability: float = 0.2,
+        annotation_mode: str = "dual",
+        clip_batch_size: int = 64,
+        text_dim: int = 512,
+    ):
+        self.n_obs_steps = n_obs_steps
+        self.max_rewind_steps = max_rewind_steps
+        self.frame_gap = frame_gap
+        self.sparse_subtask_names = sparse_subtask_names or ["task"]
+        self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0]
+        self.dense_subtask_names = dense_subtask_names
+        self.dense_temporal_proportions = dense_temporal_proportions
+        self.uses_dual_heads = annotation_mode in ["dense_only", "dual"]
+        self.image_key = image_key
+        self.state_key = state_key
+        self.max_state_dim = max_state_dim
+        self.device = device
+        self.rewind_probability = rewind_probability
+        self.language_perturbation_probability = language_perturbation_probability
+        self.annotation_mode = annotation_mode
+        self.clip_batch_size = clip_batch_size
+        self.text_dim = text_dim
+
+        # Compute observation delta indices (same as config: bidirectional)
+        half_steps = self.n_obs_steps // 2
+        past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
+        future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
+        obs_deltas = past_deltas + [0] + future_deltas
+        rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
+        self.observation_delta_indices = obs_deltas + rewind_deltas
+
+    @property
+    def num_frames(self) -> int:
+        return 1 + self.n_obs_steps + self.max_rewind_steps
+
+
+class TestSARMEncodingProcessorStepEndToEnd:
+    """End-to-end test for SARMEncodingProcessorStep with dummy batch data."""
+
+    @pytest.fixture
+    def mock_clip_model(self):
+        """Mock CLIP model to avoid loading real weights."""
+        with (
+            patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
+            patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
+        ):
+            # Mock the CLIP model - return embeddings based on input batch size
+            mock_model = MagicMock()
+
+            def get_image_features_side_effect(**kwargs):
+                pixel_values = kwargs.get("pixel_values")
+                batch_size = pixel_values.shape[0] if pixel_values is not None else 1
+                return torch.randn(batch_size, 512)
+
+            mock_model.get_image_features.side_effect = get_image_features_side_effect
+            mock_model.get_text_features.return_value = torch.randn(1, 512)
+            mock_model.to.return_value = mock_model
+            mock_model_cls.from_pretrained.return_value = mock_model
+
+            # Mock the CLIP processor - return tensors based on input images
+            mock_processor = MagicMock()
+
+            def processor_side_effect(images=None, **kwargs):
+                num_images = len(images) if images is not None else 1
+                return {
+                    "pixel_values": torch.randn(num_images, 3, 224, 224),
+                }
+
+            mock_processor.side_effect = processor_side_effect
+            # Mock tokenizer for text encoding
+            mock_processor.tokenizer.return_value = {
+                "input_ids": torch.ones(1, 77, dtype=torch.long),
+                "attention_mask": torch.ones(1, 77, dtype=torch.long),
+            }
+            mock_processor_cls.from_pretrained.return_value = mock_processor
+
+            yield mock_model, mock_processor
+
+    @pytest.fixture
+    def processor_with_mocks(self, mock_clip_model):
+        """Create a processor with mocked CLIP and dataset metadata for dual mode."""
+        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
+
+        # Dual mode config with both sparse and dense annotations
+        config = MockConfig(
+            n_obs_steps=8,
+            max_rewind_steps=4,
+            frame_gap=30,
+            rewind_probability=0.0,  # Disable for deterministic test
+            language_perturbation_probability=0.0,  # Disable for deterministic test
+            annotation_mode="dual",
+            sparse_subtask_names=["reach", "grasp", "lift"],
+            sparse_temporal_proportions=[0.3, 0.4, 0.3],
+            dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"],
+            dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
+        )
+
+        # Create mock dataset metadata with one episode of 300 frames
+        # Include annotation columns for dual mode
+        episodes = [
+            {
+                "dataset_from_index": 0,
+                "dataset_to_index": 300,
+                "task": "pick up the cube",
+                "sparse_subtask_names": ["reach", "grasp", "lift"],
+                "sparse_subtask_start_frames": [0, 90, 210],
+                "sparse_subtask_end_frames": [90, 210, 300],
+                "dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"],
+                "dense_subtask_start_frames": [0, 75, 150, 225],
+                "dense_subtask_end_frames": [75, 150, 225, 300],
+            }
+        ]
+        dataset_meta = MockDatasetMeta(episodes)
+
+        processor = SARMEncodingProcessorStep(
+            config=config,
+            dataset_meta=dataset_meta,
+        )
+        processor.train(True)  # Use train() method, not direct assignment
+
+        return processor, config
+
+    def test_call_with_single_frame_batch(self, processor_with_mocks):
+        """Test processor __call__ with a single-frame batch."""
+        processor, config = processor_with_mocks
+
+        # Create dummy input transition
+        batch_size = 1
+        num_frames = config.num_frames  # 13 frames (9 obs + 4 rewind)
+
+        # Image: (T, C, H, W) format as expected by processor
+        dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
+
+        # State: (T, D) format
+        dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
+
+        transition = {
+            TransitionKey.OBSERVATION: {
+                config.image_key: dummy_image,
+                config.state_key: dummy_state,
+            },
+            TransitionKey.COMPLEMENTARY_DATA: {
+                "index": 150,  # Middle of episode
+                "episode_index": 0,
+                "task": "pick up the cube",
+            },
+        }
+
+        # Run processor
+        result = processor(transition)
+
+        # Verify output structure
+        obs = result[TransitionKey.OBSERVATION]
+
+        # Check video features exist and have correct shape
+        assert "video_features" in obs
+        video_features = obs["video_features"]
+        assert video_features.shape[0] == batch_size
+        assert video_features.shape[1] == num_frames
+        assert video_features.shape[2] == 512  # CLIP embedding dim
+
+        # Check state features exist and have correct shape
+        assert "state_features" in obs
+        state_features = obs["state_features"]
+        assert state_features.shape[0] == batch_size
+        assert state_features.shape[1] == num_frames
+        assert state_features.shape[2] == config.max_state_dim  # Padded to max_state_dim
+
+        # Check text features exist and have correct shape
+        assert "text_features" in obs
+        text_features = obs["text_features"]
+        assert text_features.shape[0] == batch_size
+        assert text_features.shape[1] == 512  # CLIP embedding dim
+
+        # Check lengths tensor
+        assert "lengths" in obs
+        lengths = obs["lengths"]
+        assert lengths.shape[0] == batch_size
+        assert lengths.dtype == torch.int32
+
+        # Check sparse_targets exist
+        assert "sparse_targets" in obs
+        sparse_targets = obs["sparse_targets"]
+        assert sparse_targets.shape == (batch_size, num_frames)
+        # All targets should be in [0, max_stages] range (stage.tau format)
+        assert (sparse_targets >= 0).all()
+
+        # Check dense_targets exist (for dual mode)
+ assert "dense_targets" in obs + dense_targets = obs["dense_targets"] + assert dense_targets.shape == (batch_size, num_frames) + assert (dense_targets >= 0).all() + + def test_call_with_batched_input(self, mock_clip_model): + """Test processor __call__ with a batched input (multiple frames) in dual mode.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=30, + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["reach", "grasp"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["step1", "step2", "step3"], + dense_temporal_proportions=[0.33, 0.34, 0.33], + ) + + # Two episodes with different lengths, each with sparse+dense annotations + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 200, + "task": "task A", + "sparse_subtask_names": ["reach", "grasp"], + "sparse_subtask_start_frames": [0, 100], + "sparse_subtask_end_frames": [100, 200], + "dense_subtask_names": ["step1", "step2", "step3"], + "dense_subtask_start_frames": [0, 66, 133], + "dense_subtask_end_frames": [66, 133, 200], + }, + { + "dataset_from_index": 200, + "dataset_to_index": 500, + "task": "task B", + "sparse_subtask_names": ["reach", "grasp"], + "sparse_subtask_start_frames": [200, 350], + "sparse_subtask_end_frames": [350, 500], + "dense_subtask_names": ["step1", "step2", "step3"], + "dense_subtask_start_frames": [200, 300, 400], + "dense_subtask_end_frames": [300, 400, 500], + }, + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + batch_size = 2 + num_frames = config.num_frames + + # Image: (B, T, C, H, W) format + dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32) + + transition = { + 
TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": np.array([100, 350]), # One frame from each episode + "episode_index": np.array([0, 1]), + "task": ["task A", "task B"], + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + + # Verify batch dimension is preserved for all outputs + assert obs["video_features"].shape[0] == batch_size + assert obs["state_features"].shape[0] == batch_size + assert obs["lengths"].shape[0] == batch_size + assert obs["sparse_targets"].shape[0] == batch_size + assert obs["dense_targets"].shape[0] == batch_size # Dual mode has dense targets + + def test_targets_increase_with_progress(self, mock_clip_model): + """Test that both sparse and dense targets increase as frame index progresses.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=30, + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["phase1", "phase2"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["a", "b", "c", "d"], + dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25], + ) + + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 300, + "task": "test task", + "sparse_subtask_names": ["phase1", "phase2"], + "sparse_subtask_start_frames": [0, 150], + "sparse_subtask_end_frames": [150, 300], + "dense_subtask_names": ["a", "b", "c", "d"], + "dense_subtask_start_frames": [0, 75, 150, 225], + "dense_subtask_end_frames": [75, 150, 225, 300], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames + + # Test at early, middle, and late points in episode + frame_indices = [30, 150, 270] + sparse_center_targets = [] + 
dense_center_targets = [] + + for frame_idx in frame_indices: + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": frame_idx, + "episode_index": 0, + "task": "test task", + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + # Get target at center frame (index 4 in 9-frame observation window) + sparse_center_targets.append(obs["sparse_targets"][0, 4].item()) + dense_center_targets.append(obs["dense_targets"][0, 4].item()) + + # Both sparse and dense targets should increase with frame index + assert sparse_center_targets[0] < sparse_center_targets[2], ( + f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})" + ) + assert dense_center_targets[0] < dense_center_targets[2], ( + f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})" + ) + + def test_progress_labels_exact_values(self, mock_clip_model): + """Test that progress labels (stage.tau) are computed correctly for known positions.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + # Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=10, # Smaller gap for easier calculation + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["A", "B"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["d1", "d2", "d3", "d4"], + dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25], + ) + + # Episode: frames 0-99, sparse stages at [0-49], [50-99] + # Dense stages at [0-24], [25-49], [50-74], [75-99] + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 
100, + "task": "test", + "sparse_subtask_names": ["A", "B"], + "sparse_subtask_start_frames": [0, 50], + "sparse_subtask_end_frames": [50, 100], + "dense_subtask_names": ["d1", "d2", "d3", "d4"], + "dense_subtask_start_frames": [0, 25, 50, 75], + "dense_subtask_end_frames": [25, 50, 75, 100], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames + + # Test at frame 50 (center of episode) + # With frame_gap=10, n_obs_steps=8: + # obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames) + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": 50, + "episode_index": 0, + "task": "test", + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + sparse_targets = obs["sparse_targets"][0] # (13,) + dense_targets = obs["dense_targets"][0] # (13,) + + # First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind) + # Check that obs frames have non-zero targets + obs_sparse = sparse_targets[:9] + obs_dense = dense_targets[:9] + + # Verify targets are monotonically increasing for observation frames + for i in range(1, 9): + assert obs_sparse[i] >= obs_sparse[i - 1], ( + f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}" + ) + assert obs_dense[i] >= obs_dense[i - 1], ( + f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}" + ) + + # Rewind slots should be zero when rewind is disabled + rewind_targets = sparse_targets[9:] + assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled" + + # 
Check stage transitions: frame 50 is at boundary of sparse stage A->B + # Center frame (index 4) corresponds to actual frame 50 + center_sparse = obs_sparse[4].item() + # At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0) + assert 0.9 <= center_sparse <= 1.1, ( + f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}" + ) + + def test_rewind_augmentation_applied(self, mock_clip_model): + """Test that rewind augmentation correctly extends sequence and generates targets.""" + import random + + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + + config = MockConfig( + n_obs_steps=8, + max_rewind_steps=4, + frame_gap=10, + rewind_probability=1.0, # Always apply rewind + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["A", "B"], + sparse_temporal_proportions=[0.5, 0.5], + dense_subtask_names=["d1", "d2"], + dense_temporal_proportions=[0.5, 0.5], + ) + + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 200, + "task": "test", + "sparse_subtask_names": ["A", "B"], + "sparse_subtask_start_frames": [0, 100], + "sparse_subtask_end_frames": [100, 200], + "dense_subtask_names": ["d1", "d2"], + "dense_subtask_start_frames": [0, 100], + "dense_subtask_end_frames": [100, 200], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames # 13 + + # Test at frame 150 (center of bidirectional window) + # With n_obs_steps=8, half_steps=4, frame_gap=10: + # - Earliest obs frame = 150 - 4*10 = 110 + # - Rewind can go back from 110 to frames like 100, 90, 80, 70 + # - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4) + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + 
config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": 150, + "episode_index": 0, + "task": "test", + }, + } + + # Seed random for reproducibility + random.seed(42) + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + + lengths = obs["lengths"][0].item() + sparse_targets = obs["sparse_targets"][0] + + # With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind) + assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}" + assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}" + + # Rewind targets should be non-zero for frames within valid length + n_obs_frames = 9 + rewind_count = lengths - n_obs_frames + + if rewind_count > 0: + # Check that rewind frames have targets + rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count] + # Rewind frames are from BEFORE the earliest obs frame (110) + # These frames (100, 90, 80, 70) are earlier in the episode + earliest_obs_target = sparse_targets[0].item() # Frame 110 + + # Rewind targets should be less than earliest obs (they're from earlier frames) + for i, rt in enumerate(rewind_targets): + assert rt.item() < earliest_obs_target, ( + f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})" + ) + + # Rewind targets should be decreasing (going further back in time) + for i in range(1, len(rewind_targets)): + assert rewind_targets[i] <= rewind_targets[i - 1], ( + f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}" + ) + + def test_full_sequence_target_consistency(self, mock_clip_model): + """Test that the full sequence of targets is consistent with frame positions.""" + from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep + from lerobot.policies.sarm.sarm_utils import find_stage_and_tau + + config = MockConfig( + 
n_obs_steps=8, + max_rewind_steps=4, + frame_gap=10, + rewind_probability=0.0, + language_perturbation_probability=0.0, + annotation_mode="dual", + sparse_subtask_names=["s1", "s2", "s3"], + sparse_temporal_proportions=[0.33, 0.34, 0.33], + dense_subtask_names=["d1", "d2"], + dense_temporal_proportions=[0.5, 0.5], + ) + + # 3 sparse stages: [0-33), [33-66), [66-99] + # 2 dense stages: [0-50), [50-100) + episodes = [ + { + "dataset_from_index": 0, + "dataset_to_index": 100, + "task": "test", + "sparse_subtask_names": ["s1", "s2", "s3"], + "sparse_subtask_start_frames": [0, 33, 66], + "sparse_subtask_end_frames": [33, 66, 100], + "dense_subtask_names": ["d1", "d2"], + "dense_subtask_start_frames": [0, 50], + "dense_subtask_end_frames": [50, 100], + } + ] + dataset_meta = MockDatasetMeta(episodes) + + processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta) + processor.train(True) + + num_frames = config.num_frames + + # Test at frame 50 (middle of episode) + dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32) + dummy_state = np.random.rand(num_frames, 6).astype(np.float32) + + transition = { + TransitionKey.OBSERVATION: { + config.image_key: dummy_image, + config.state_key: dummy_state, + }, + TransitionKey.COMPLEMENTARY_DATA: { + "index": 50, + "episode_index": 0, + "task": "test", + }, + } + + result = processor(transition) + obs = result[TransitionKey.OBSERVATION] + sparse_targets = obs["sparse_targets"][0] + dense_targets = obs["dense_targets"][0] + + # Manually compute expected targets for observation frames + # With frame_gap=10, n_obs_steps=8, center at 50: + # obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90] + expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90] + + sparse_names = ["s1", "s2", "s3"] + sparse_starts = [0, 33, 66] + sparse_ends = [33, 66, 100] + sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33} + + dense_names = ["d1", "d2"] + dense_starts = [0, 50] + dense_ends = [50, 100] + dense_props = 
{"d1": 0.5, "d2": 0.5} + + for i, frame in enumerate(expected_obs_frames): + expected_sparse = find_stage_and_tau( + frame, + 100, + sparse_names, + sparse_starts, + sparse_ends, + sparse_names, + sparse_props, + return_combined=True, + ) + expected_dense = find_stage_and_tau( + frame, + 100, + dense_names, + dense_starts, + dense_ends, + dense_names, + dense_props, + return_combined=True, + ) + + actual_sparse = sparse_targets[i].item() + actual_dense = dense_targets[i].item() + + assert abs(actual_sparse - expected_sparse) < 0.01, ( + f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}" + ) + assert abs(actual_dense - expected_dense) < 0.01, ( + f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}" + ) diff --git a/tests/policies/test_sarm_subtask_annotations.py b/tests/policies/test_sarm_subtask_annotations.py new file mode 100644 index 000000000..0dc087288 --- /dev/null +++ b/tests/policies/test_sarm_subtask_annotations.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +pytest.importorskip("transformers") + +from lerobot.data_processing.sarm_annotations.subtask_annotation import ( + Subtask, + SubtaskAnnotation, + Timestamp, + compute_temporal_proportions, +) + + +def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation: + """Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec).""" + return SubtaskAnnotation( + subtasks=[ + Subtask( + name=name, + timestamps=Timestamp( + start=f"{start // 60:02d}:{start % 60:02d}", end=f"{end // 60:02d}:{end % 60:02d}" + ), + ) + for name, start, end in subtasks + ] + ) + + +class TestComputeTemporalProportions: + """Tests for compute_temporal_proportions (SARM Paper Formula 1). + + Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) + + Key insight: This averages the PROPORTION of each subtask within each trajectory, + giving equal weight to all trajectories regardless of absolute length. + """ + + def test_basic_two_trajectories_equal_proportions(self): + """Test with two trajectories that have equal proportions.""" + # Both trajectories: subtask1 = 50%, subtask2 = 50% + # Traj 1: T=100s, subtask1=50s, subtask2=50s + # Traj 2: T=200s, subtask1=100s, subtask2=100s + annotations = { + 0: make_annotation([("subtask1", 0, 50), ("subtask2", 50, 100)]), + 1: make_annotation([("subtask1", 0, 100), ("subtask2", 100, 200)]), + } + + result = compute_temporal_proportions(annotations) + + # Both should be 0.5 + assert abs(result["subtask1"] - 0.5) < 1e-6 + assert abs(result["subtask2"] - 0.5) < 1e-6 + + def test_paper_example_different_from_avg_durations(self): + """Test that compute_temporal_proportions differs from naive average duration approach. 
+
+        This is the key test showing the difference between:
+        - Paper formula: average of (L_i,k / T_i)
+        - Naive approach: mean(L_i,k) / sum(mean(L_i,j))
+        """
+        # Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
+        # Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
+        annotations = {
+            0: make_annotation([("subtask1", 0, 80), ("subtask2", 80, 100)]),
+            1: make_annotation([("subtask1", 0, 40), ("subtask2", 40, 200)]),
+        }
+
+        result = compute_temporal_proportions(annotations)
+
+        # Paper formula (the naive approach would give 60/150 = 0.4 and 90/150 = 0.6 instead):
+        # ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
+        # ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
+        assert abs(result["subtask1"] - 0.5) < 1e-6
+        assert abs(result["subtask2"] - 0.5) < 1e-6
+
+    def test_single_trajectory(self):
+        """Test with a single trajectory."""
+        # T=100s, reach=30s, grasp=20s, lift=50s
+        annotations = {
+            0: make_annotation([("reach", 0, 30), ("grasp", 30, 50), ("lift", 50, 100)]),
+        }
+
+        result = compute_temporal_proportions(annotations)
+
+        assert abs(result["reach"] - 0.3) < 1e-6
+        assert abs(result["grasp"] - 0.2) < 1e-6
+        assert abs(result["lift"] - 0.5) < 1e-6
+
+    def test_sum_to_one(self):
+        """Test that proportions always sum to 1."""
+        # Three episodes with varying proportions
+        annotations = {
+            0: make_annotation([("a", 0, 10), ("b", 10, 50), ("c", 50, 100)]),  # 0.1, 0.4, 0.5
+            1: make_annotation([("a", 0, 20), ("b", 20, 70), ("c", 70, 100)]),  # 0.2, 0.5, 0.3
+            2: make_annotation([("a", 0, 30), ("b", 30, 90), ("c", 90, 100)]),  # 0.3, 0.6, 0.1
+        }
+
+        result = compute_temporal_proportions(annotations)
+
+        total = sum(result.values())
+        assert abs(total - 1.0) < 1e-6
+
+    def test_empty_annotations_returns_empty(self):
+        """Test that empty annotations return an empty dict."""
+        result = compute_temporal_proportions({})
+        assert result == {}
+
+    def test_uniform_proportions(self):
+        """Test with uniform proportions across subtasks."""
+        # Each subtask takes 25% of each
episode + annotations = { + 0: make_annotation([("a", 0, 25), ("b", 25, 50), ("c", 50, 75), ("d", 75, 100)]), + 1: make_annotation([("a", 0, 50), ("b", 50, 100), ("c", 100, 150), ("d", 150, 200)]), + } + + result = compute_temporal_proportions(annotations) + + for name in ["a", "b", "c", "d"]: + assert abs(result[name] - 0.25) < 1e-6 diff --git a/tests/policies/test_sarm_utils.py b/tests/policies/test_sarm_utils.py new file mode 100644 index 000000000..510477ec8 --- /dev/null +++ b/tests/policies/test_sarm_utils.py @@ -0,0 +1,615 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import pytest +import torch + +from lerobot.policies.sarm.sarm_utils import ( + apply_rewind_augmentation, + compute_absolute_indices, + compute_tau, + find_stage_and_tau, + normalize_stage_tau, + temporal_proportions_to_breakpoints, +) + + +class TestProgressLabelsWithModes: + """End-to-end tests for progress label generation in different modes.""" + + def test_sparse_mode_single_stage(self): + """Sparse mode with single stage should give linear progress.""" + episode_length = 300 + global_names = ["task"] + proportions = {"task": 1.0} + + # Test at various frames + for frame in [0, 100, 200, 299]: + stage, tau = find_stage_and_tau( + frame, episode_length, None, None, None, global_names, proportions + ) + + expected_tau = frame / (episode_length - 1) + assert stage == 0 + assert abs(tau - expected_tau) < 1e-5 + + def test_sparse_mode_multi_stage(self): + """Sparse mode with multiple stages.""" + global_names = ["reach", "grasp", "lift", "place"] + proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3} + + subtask_names = ["reach", "grasp", "lift", "place"] + subtask_starts = [0, 60, 120, 210] + subtask_ends = [59, 119, 209, 299] + + # Check stages are correctly identified + stage_at_30, _ = find_stage_and_tau( + 30, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage_at_30 == 0 + + stage_at_90, _ = find_stage_and_tau( + 90, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage_at_90 == 1 + + stage_at_150, _ = find_stage_and_tau( + 150, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage_at_150 == 2 + + def test_dense_mode_more_stages(self): + """Dense mode should work with more fine-grained stages.""" + global_names = ["a", "b", "c", "d", "e", "f", "g", "h"] + proportions = dict.fromkeys(global_names, 1 / 8) + + subtask_names = global_names + subtask_starts = [i * 50 for i in range(8)] + subtask_ends 
= [(i + 1) * 50 - 1 for i in range(8)] + + # Each stage should occupy 50 frames + for stage_idx in range(8): + mid_frame = stage_idx * 50 + 25 + stage, _ = find_stage_and_tau( + mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == stage_idx + + +class TestComputeAbsoluteIndices: + """Tests for compute_absolute_indices (bidirectional sampling).""" + + def test_no_clamping_when_in_middle(self): + """When frame is in middle of episode, no clamping should occur.""" + frame_idx = 300 + ep_start = 0 + ep_end = 1000 + n_obs_steps = 8 + frame_gap = 30 + + indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap) + + # All should be valid (no out of bounds) + assert out_of_bounds.sum() == 0 + + # Check bidirectional indices: [-120, -90, -60, -30, 0, 30, 60, 90, 120] from center + half_steps = n_obs_steps // 2 + expected = ( + [frame_idx - frame_gap * i for i in range(half_steps, 0, -1)] + + [frame_idx] + + [frame_idx + frame_gap * i for i in range(1, half_steps + 1)] + ) + assert indices.tolist() == expected + + # Center frame (index 4) should be frame_idx + assert indices[half_steps] == frame_idx + + def test_clamping_at_episode_start(self): + """Early frames should be clamped to episode start.""" + frame_idx = 50 # Not enough history for full past window + ep_start = 0 + ep_end = 1000 + n_obs_steps = 8 + frame_gap = 30 + + indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap) + + # Some past frames should be clamped (out_of_bounds = 1) + assert out_of_bounds.sum() > 0 + + # All indices should be >= ep_start + assert (indices >= ep_start).all() + + # Center index should be frame_idx + half_steps = n_obs_steps // 2 + assert indices[half_steps] == frame_idx + + def test_clamping_at_episode_end(self): + """Late frames should be clamped to episode end.""" + frame_idx = 950 # Not enough future for full window + ep_start = 0 + 
ep_end = 1000 + n_obs_steps = 8 + frame_gap = 30 + + indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap) + + # Some future frames should be clamped + assert out_of_bounds.sum() > 0 + + # All indices should be < ep_end + assert (indices < ep_end).all() + + # Center index should be frame_idx + half_steps = n_obs_steps // 2 + assert indices[half_steps] == frame_idx + + def test_sequence_is_monotonic(self): + """Frame indices should be monotonically increasing.""" + for frame_idx in [50, 100, 300, 950]: + indices, _ = compute_absolute_indices(frame_idx, 0, 1000, 8, 30) + + # Check monotonic (non-decreasing due to clamping) + diffs = indices[1:] - indices[:-1] + assert (diffs >= 0).all(), f"Non-monotonic at frame {frame_idx}" + + +class TestComputeTau: + """Tests for compute_tau (within-subtask progress). + + Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1] + """ + + def test_at_start(self): + """τ should be 0 at subtask start.""" + tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50) + assert tau == 0.0 + + def test_at_end(self): + """τ should be 1 at subtask end.""" + tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50) + assert tau == 1.0 + + def test_at_middle(self): + """τ should be 0.5 at subtask midpoint.""" + tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50) + assert abs(tau - 0.5) < 1e-6 + + def test_quarter_progress(self): + """Test τ at 25% through subtask.""" + tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80) + assert abs(tau - 0.25) < 1e-6 + + def test_zero_duration_subtask(self): + """τ should be 1.0 for zero-duration subtask.""" + tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10) + assert tau == 1.0 + + def test_clamps_below_zero(self): + """τ should be clamped to 0 if frame is before subtask.""" + tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50) + assert tau == 0.0 + + def 
test_clamps_above_one(self): + """τ should be clamped to 1 if frame is after subtask.""" + tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50) + assert tau == 1.0 + + def test_float_inputs(self): + """Test with float frame indices (from interpolation).""" + tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0) + expected = (25.5 - 10.0) / (50.0 - 10.0) + assert abs(tau - expected) < 1e-6 + + +class TestFindStageAndTau: + """Tests for find_stage_and_tau logic. + + This function is the core of progress label computation. It determines + which stage a frame belongs to and the within-stage progress (tau). + """ + + def test_single_stage_mode_linear_progress(self): + """Single-stage mode should give linear progress from 0 to 1.""" + episode_length = 100 + + # Frame 0 -> tau = 0 + stage, tau = find_stage_and_tau(0, episode_length, None, None, None, ["task"], {"task": 1.0}) + assert stage == 0 + assert abs(tau - 0.0) < 1e-6 + + # Frame 50 -> tau = 0.505 (50/99) + stage, tau = find_stage_and_tau(50, episode_length, None, None, None, ["task"], {"task": 1.0}) + assert stage == 0 + assert abs(tau - 50 / 99) < 1e-6 + + # Frame 99 -> tau = 1.0 + stage, tau = find_stage_and_tau(99, episode_length, None, None, None, ["task"], {"task": 1.0}) + assert stage == 0 + assert abs(tau - 1.0) < 1e-6 + + def test_multi_stage_within_subtask(self): + """Test finding stage when frame is within a subtask.""" + global_names = ["reach", "grasp", "lift"] + proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5} + + subtask_names = ["reach", "grasp", "lift"] + subtask_starts = [0, 30, 50] + subtask_ends = [29, 49, 99] + + # Frame 15 in "reach" stage (index 0) + stage, tau = find_stage_and_tau( + 15, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 0 + assert abs(tau - 15 / 29) < 1e-6 + + # Frame 40 in "grasp" stage (index 1) + stage, tau = find_stage_and_tau( + 40, 100, subtask_names, subtask_starts, 
subtask_ends, global_names, proportions + ) + assert stage == 1 + # tau = (40 - 30) / (49 - 30) = 10/19 + assert abs(tau - 10 / 19) < 1e-6 + + # Frame 75 in "lift" stage (index 2) + stage, tau = find_stage_and_tau( + 75, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 2 + # tau = (75 - 50) / (99 - 50) = 25/49 + assert abs(tau - 25 / 49) < 1e-6 + + def test_frame_at_subtask_boundaries(self): + """Test frames exactly at subtask boundaries.""" + global_names = ["a", "b"] + proportions = {"a": 0.5, "b": 0.5} + + subtask_names = ["a", "b"] + subtask_starts = [0, 50] + subtask_ends = [49, 99] + + # Frame at start of first subtask + stage, tau = find_stage_and_tau( + 0, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 0 + assert tau == 0.0 + + # Frame at end of first subtask + stage, tau = find_stage_and_tau( + 49, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 0 + assert tau == 1.0 + + # Frame at start of second subtask + stage, tau = find_stage_and_tau( + 50, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 1 + assert tau == 0.0 + + def test_frame_after_last_subtask(self): + """Frames after last subtask should return last stage with high tau.""" + global_names = ["a", "b"] + proportions = {"a": 0.5, "b": 0.5} + + subtask_names = ["a", "b"] + subtask_starts = [0, 30] + subtask_ends = [29, 59] + + # Frame 80 is after last subtask + stage, tau = find_stage_and_tau( + 80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions + ) + assert stage == 1 # Last stage + assert tau == 0.999 # Nearly complete + + +class TestEndToEndProgressLabeling: + """End-to-end tests for progress label computation using normalize_stage_tau.""" + + def test_consistent_semantic_meaning(self): + """Test that same subtask completion maps to same progress across trajectories. 
+ + This is the key semantic property: "end of subtask 1" should always + mean the same progress value regardless of trajectory speed. + """ + proportions = [0.3, 0.5, 0.2] + + # Fast trajectory: subtask 1 ends at frame 30 (of 100) + tau_fast = compute_tau(30, 0, 30) # = 1.0 + y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions) + + # Slow trajectory: subtask 1 ends at frame 90 (of 300) + tau_slow = compute_tau(90, 0, 90) # = 1.0 + y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions) + + # Both should map to same progress (0.3 = end of subtask 1) + assert abs(y_fast - y_slow) < 1e-6 + assert abs(y_fast - 0.3) < 1e-6 + + def test_monotonic_within_subtask(self): + """Test that progress is monotonically increasing within a subtask.""" + proportions = [0.4, 0.6] + + prev_y = -1 + for tau in np.linspace(0, 1, 11): + y = normalize_stage_tau(0 + tau, temporal_proportions=proportions) + assert y > prev_y or (tau == 0 and y == 0) + prev_y = y + + def test_continuous_across_subtasks(self): + """Test that progress is continuous at subtask boundaries.""" + proportions = [0.3, 0.5, 0.2] + + # End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0 + y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions) + + # Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0 + y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions) + + # Should be equal (P_1 = 0.3) + assert abs(y_end_0 - y_start_1) < 1e-6 + + # End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0 + y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions) + + # Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0 + y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions) + + # Should be equal (P_2 = 0.8) + assert abs(y_end_1 - y_start_2) < 1e-6 + + +class TestTemporalProportionsToBreakpoints: + """Tests for temporal_proportions_to_breakpoints. 
+ + Converts temporal proportions to cumulative breakpoints for normalization. + Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0] + """ + + def test_basic_conversion(self): + """Test basic conversion from proportions to breakpoints.""" + proportions = [0.3, 0.5, 0.2] + breakpoints = temporal_proportions_to_breakpoints(proportions) + + assert breakpoints is not None + assert len(breakpoints) == 4 + assert breakpoints[0] == 0.0 + assert abs(breakpoints[1] - 0.3) < 1e-6 + assert abs(breakpoints[2] - 0.8) < 1e-6 + assert breakpoints[3] == 1.0 + + def test_dict_input(self): + """Test with dict input.""" + proportions = {"a": 0.25, "b": 0.25, "c": 0.5} + breakpoints = temporal_proportions_to_breakpoints(proportions) + + assert breakpoints is not None + assert len(breakpoints) == 4 + assert breakpoints[0] == 0.0 + assert breakpoints[-1] == 1.0 + + def test_dict_with_subtask_names_order(self): + """Test that subtask_names determines order for dict input.""" + proportions = {"c": 0.5, "a": 0.2, "b": 0.3} # Dict order + subtask_names = ["a", "b", "c"] # Different order + + breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names) + + # Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5 + assert abs(breakpoints[1] - 0.2) < 1e-6 # a + assert abs(breakpoints[2] - 0.5) < 1e-6 # a + b = 0.5 + assert breakpoints[3] == 1.0 # a + b + c = 1.0 + + def test_uniform_proportions(self): + """Test with uniform proportions.""" + proportions = [0.25, 0.25, 0.25, 0.25] + breakpoints = temporal_proportions_to_breakpoints(proportions) + + expected = [0.0, 0.25, 0.5, 0.75, 1.0] + for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)): + assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch" + + def test_none_input(self): + """Test that None input returns None.""" + result = temporal_proportions_to_breakpoints(None) + assert result is None + + def test_normalization(self): + """Test that non-normalized proportions are normalized.""" + # 
Proportions sum to 2.0, not 1.0 + proportions = [0.6, 1.0, 0.4] + breakpoints = temporal_proportions_to_breakpoints(proportions) + + # Should be normalized: [0.3, 0.5, 0.2] -> [0, 0.3, 0.8, 1.0] + assert breakpoints[-1] == 1.0 + assert abs(breakpoints[1] - 0.3) < 1e-6 + + +class TestNormalizeStageTau: + """Tests for normalize_stage_tau. + + Normalizes stage+tau values to [0, 1] using breakpoints. + """ + + def test_linear_fallback(self): + """Test linear normalization when only num_stages is provided.""" + # 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0] + + # Stage 0 start + assert normalize_stage_tau(0.0, num_stages=4) == 0.0 + + # Stage 0 end / Stage 1 start + assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6 + + # Stage 1 middle + assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6 + + # Stage 3 end + assert normalize_stage_tau(4.0, num_stages=4) == 1.0 + + def test_with_custom_breakpoints(self): + """Test with custom breakpoints.""" + # Non-linear breakpoints + breakpoints = [0.0, 0.1, 0.5, 1.0] # 3 stages + + # Stage 0: maps [0, 1) to [0.0, 0.1) + assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6 + + # Stage 1: maps [1, 2) to [0.1, 0.5) + assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6 + + # Stage 2: maps [2, 3) to [0.5, 1.0) + assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6 + + def test_with_temporal_proportions(self): + """Test with temporal proportions (auto-computed breakpoints).""" + proportions = {"a": 0.2, "b": 0.3, "c": 0.5} + subtask_names = ["a", "b", "c"] + + # Stage 0 end should map to 0.2 + result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names) + assert abs(result - 0.2) < 1e-6 + + # Stage 1 end should map to 0.5 + result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names) + assert abs(result - 0.5) < 1e-6 + + def test_tensor_input(self): + """Test with tensor 
input.""" + x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0]) + breakpoints = [0.0, 0.3, 0.8, 1.0] # 3 stages + + result = normalize_stage_tau(x, breakpoints=breakpoints) + + assert isinstance(result, torch.Tensor) + assert result.shape == x.shape + assert abs(result[0].item() - 0.0) < 1e-6 + assert abs(result[2].item() - 0.3) < 1e-6 # End of stage 0 + assert abs(result[4].item() - 0.8) < 1e-6 # End of stage 1 + + def test_clamping(self): + """Test that output is clamped to [0, 1].""" + # Below 0 + assert normalize_stage_tau(-0.5, num_stages=4) == 0.0 + + # Above num_stages + assert normalize_stage_tau(5.0, num_stages=4) == 1.0 + + def test_batch_tensor(self): + """Test with batched tensor.""" + x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]]) # (2, 3) + + result = normalize_stage_tau(x, num_stages=3) + + assert result.shape == (2, 3) + assert (result >= 0).all() + assert (result <= 1).all() + + def test_requires_one_of_inputs(self): + """Test that at least one input method is required.""" + with pytest.raises(ValueError): + normalize_stage_tau(1.0) + + +class TestRewindAugmentation: + """Tests for rewind augmentation logic with bidirectional observation sampling. + + Rewind appends frames before the earliest observation frame, going backwards. 
+ With bidirectional sampling centered at frame_idx: + - Earliest obs frame = frame_idx - half_steps * frame_gap + - Rewind goes backwards from that point + """ + + def test_rewind_indices_go_backwards_from_earliest_obs(self): + """Rewind indices should go backwards from earliest observation frame.""" + frame_idx = 300 # Center of bidirectional window + ep_start = 0 + n_obs_steps = 4 # half_steps = 2 + frame_gap = 30 + + # Earliest obs frame = 300 - 2*30 = 240 + # Rewind goes backwards: 210, 180 + rewind_step, rewind_indices = apply_rewind_augmentation( + frame_idx, + ep_start, + n_obs_steps=n_obs_steps, + max_rewind_steps=2, + frame_gap=frame_gap, + rewind_step=2, + ) + + assert rewind_step == 2 + assert len(rewind_indices) == 2 + # First rewind frame is closest to obs window, second is further back + assert rewind_indices[0] == 210 # 240 - 30 + assert rewind_indices[1] == 180 # 240 - 60 + assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending" + + def test_rewind_goes_backward_through_history(self): + """Rewind frames should go backward before the observation window.""" + frame_idx = 450 # Center of bidirectional window + ep_start = 0 + n_obs_steps = 8 # half_steps = 4 + frame_gap = 30 + + # Earliest obs frame = 450 - 4*30 = 330 + # Rewind from 330: [300, 270, 240] + rewind_step, rewind_indices = apply_rewind_augmentation( + frame_idx, + ep_start, + n_obs_steps=n_obs_steps, + max_rewind_steps=4, + frame_gap=frame_gap, + rewind_step=3, + ) + + assert rewind_step == 3 + expected = [300, 270, 240] # Going backwards from 330 + assert rewind_indices == expected + + def test_no_rewind_when_obs_window_at_episode_start(self): + """No rewind when observation window reaches episode start.""" + frame_idx = 120 # Center of window + ep_start = 0 + n_obs_steps = 8 # half_steps = 4 + frame_gap = 30 + + # Earliest obs frame = 120 - 4*30 = 0 (at episode start) + rewind_step, rewind_indices = apply_rewind_augmentation( + frame_idx, ep_start, 
n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap + ) + + # No room for rewind + assert rewind_step == 0 + assert rewind_indices == [] + + def test_rewind_targets_are_decreasing(self): + """Progress targets for rewind frames should be decreasing.""" + # Simulate progress values + obs_progress = [0.1, 0.2, 0.3, 0.4, 0.5] # Forward progress + + # Rewind reverses progress + rewind_indices = [4, 3, 2] # Go backwards through indices + rewind_progress = [obs_progress[i] for i in rewind_indices] + + # Should be decreasing + for i in range(len(rewind_progress) - 1): + assert rewind_progress[i] > rewind_progress[i + 1]
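
The tests above pin down the contract of `temporal_proportions_to_breakpoints` and `normalize_stage_tau` without showing the implementation. For readers following along without the library source, here is a minimal scalar-only sketch that is consistent with the behaviour these tests exercise; it is an illustration, not the actual implementation (which, per `test_tensor_input` and `test_batch_tensor`, also accepts `torch` tensors and batched inputs):

```python
def temporal_proportions_to_breakpoints(proportions, subtask_names=None):
    # [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]; dict values are ordered by
    # subtask_names when given, otherwise by dict insertion order.
    if proportions is None:
        return None
    if isinstance(proportions, dict):
        keys = subtask_names if subtask_names is not None else list(proportions)
        values = [proportions[k] for k in keys]
    else:
        values = list(proportions)
    total = sum(values)  # normalize in case proportions do not sum to 1.0
    breakpoints = [0.0]
    for v in values:
        breakpoints.append(breakpoints[-1] + v / total)
    breakpoints[-1] = 1.0  # guard against floating-point drift
    return breakpoints


def normalize_stage_tau(x, breakpoints=None, temporal_proportions=None,
                        subtask_names=None, num_stages=None):
    # Map a stage+tau value (e.g. 1.5 = halfway through stage 1) to normalized
    # progress in [0, 1] by piecewise-linear interpolation between breakpoints.
    if breakpoints is None and temporal_proportions is not None:
        breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
    if breakpoints is None and num_stages is not None:
        # Linear fallback: uniform breakpoints [0, 1/K, 2/K, ..., 1].
        breakpoints = [k / num_stages for k in range(num_stages + 1)]
    if breakpoints is None:
        raise ValueError("Provide breakpoints, temporal_proportions, or num_stages")
    n_stages = len(breakpoints) - 1
    x = min(max(float(x), 0.0), float(n_stages))  # clamp stage+tau to valid range
    stage = min(int(x), n_stages - 1)
    tau = x - stage
    lo, hi = breakpoints[stage], breakpoints[stage + 1]
    return lo + tau * (hi - lo)
```

The piecewise-linear mapping is what realizes Formula (2) from the paper, `progress_t = P_{k-1} + α̅_k × τ_t`: `lo` is the cumulative prior `P_{k-1}` and `hi - lo` is the subtask proportion `α̅_k`, which is why the continuity and consistency tests above hold by construction.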