mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Add sarm (#2639)
* add initial modeling * make rewind pretrained policy * add annotation * small fix * add sarm * subtasks * fix spawn * fix rewind discrepancies * Add script to generate embedding for dataset (#2138) * Add generate and validate script * fix precommit * Improve generate embeddings function by using dataset tools (#2206) --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> * cleanup * change order train log * print batch size * update sarm processor * add reward output * change expected features * add image validation * change validation * get state input from dataset stats * raise if no state key is found * pass stats * cleanup and refactor * add episode inddex to complementary data * add subtask init and detection * revert lerobot_train changes * pass dataset metadata to policy * change loadig subtasks * add small logging * fix progress conversion and adding initial frame * use large offset for initial frame (ugly) * Remove rewind, use clip tokenizer * add tests, implement formula 1,2 correctly and cleanup * use task from dataset, cleanup visualizer * simplify * simplify and cleanup code and move compute_temporal_proportions to utils * fix normalization in visualization * Fix visualization and change prompt * fix formatting * add visualize subtask annotations * use qwen thinking * try different prompt * format * update prompt * higher temp, long output * different settings * use instruct * show full resp * split message * Temp: increase tolerance dataset * Fix RA-BC (#2572) * Add next observation loading for RA-BC progress deltas * Compute weights based on temporal progress deltas instead of static rewards * Add hard-masking for negative progress deltas in weight computation * Feat/add dual head (#2582) * Add dual dense sparse head and annotation * Add docs * add dual to procesor * cleanup * change sampling in visualize and cleanup * remove validation * remove compile * Feat/test uniform (#2587) * test uniform * add different string for misaligned * Fix rewind and add tests * uncomment text implementation * run precommit * Add head mode for ra-bc * fix visalization of single task * add * return per sample loss * Fix RA_BC (#2602) * update rabc implementation * compute rabc beforehand * fix import * add only progress calulation * use precomputed progress * multi gpu processing * import * fix dataset meta data extraction * add logging * logging * log * progress per episode * split differently * move clip to gpu * pre decode frames for an episode * fix cuda initalization * fix import * multi processing * rename * fix import * fix * fix rabc * use last known progress if oob * use last known progress if oob * add misalignment loss with random embeddings * discard previous changes * add selection of models to docs for ra_bc * add transformers dep * extend tolerance * initial commit with new codebase * add tests * fix * remove temporal sampler * drop last frame for sampler * use original ref * some fixes * fix visualization * remove smoothing and fix order subtasks * add stride rabc computation * add push to hub * add explanation * add kappa expllaination * better rabc logging * feedback pr * remove dataset tolerance * revert dataset tool * revert dataset changes * add credit * run precommit * change path for generate ra_bc * fix type * include sarm in all in pyproject * fix precommit * lazy import matplotlib * lazy import qwen * remove rich console * skip if transformers is not installed? 
* run only when we have faker * place transformer lazy loading * Dont test if low transformer version * fix * increase transformer * increase as 4.57.0 is yanked * remove pi from all * go back --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
@@ -42,6 +42,10 @@
       - local: xvla
         title: X-VLA
     title: "Policies"
+  - sections:
+      - local: sarm
+        title: SARM
+    title: "Reward Models"
   - sections:
       - local: async
         title: Use Async Inference
@@ -0,0 +1,586 @@
# SARM: Stage-Aware Reward Modeling

SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and, optionally, how to use them with Reward-Aligned Behavior Cloning (RA-BC).

**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)

## Why Reward Models?

Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy: they contain hesitations, corrections, and variable-quality trajectories. Reward models address this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways; two promising applications are (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.

## Overview

SARM has the following features:

1. **Stage-aware architecture**: Jointly predicts the high-level task stage and the fine-grained progress within each stage
2. **Subtask annotations**: Uses natural-language subtask annotations to derive consistent progress labels
3. **Temporal proportions**: Computes dataset-level priors (`α̅_k`) for each subtask to normalize progress across variable-length demonstrations

SARM trains on a compact **stage+tau** target for each frame:

- **stage**: integer stage index `k ∈ {0, ..., K-1}`
- **τ (tau)**: within-stage progress `τ ∈ [0, 1]`
- **target encoding**: `y = k + τ` (this is what the dataset processor produces)

At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using the dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`).

This matches **Formula (2)** from the paper:

```
progress_t = P_{k-1} + α̅_k × τ_t
```

Where:

- `τ_t = (t - s_k) / (e_k - s_k)` is the within-subtask normalized time
- `P_{k-1}` is the cumulative prior (the sum of the previous subtask proportions)
- `α̅_k` is the temporal proportion for subtask k

This ensures identical task states map to consistent progress values, even across demonstrations of different lengths.
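
To make the conversion concrete, here is a minimal Python sketch of Formula (2). It mirrors what the `normalize_stage_tau` helper in `sarm_utils` computes; the function below is illustrative, not the library implementation:

```python
def stage_tau_to_progress(y: float, proportions: list[float]) -> float:
    """Map a raw `y = k + τ` target to normalized progress in [0, 1].

    proportions: dataset-level α̅_k values, one per subtask, summing to 1.
    """
    k = min(int(y), len(proportions) - 1)    # stage index k
    tau = y - k                              # within-stage progress τ
    cumulative_prior = sum(proportions[:k])  # P_{k-1}
    return cumulative_prior + proportions[k] * tau


# Three subtasks covering 20%, 50%, and 30% of an episode on average:
# halfway through stage 1 (y = 1.5) → 0.2 + 0.5 × 0.5 = 0.45
print(stage_tau_to_progress(1.5, [0.2, 0.5, 0.3]))  # 0.45
```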

## Inputs and Targets (What the new code expects)

SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:

- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
- **Pads/truncates** the robot state into `state_features` (up to `max_state_dim`)
- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual` mode) using the stage+tau encoding `y = k + τ`
- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation)

At minimum, each training sample needs:

- `task` (string): task description
- `policy.image_key` images and `policy.state_key` states from the dataset
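
For orientation, this is roughly what one raw sample looks like before the processor runs. The keys follow the defaults used in this guide; the tensor shapes are illustrative assumptions, not requirements:

```python
import torch

sample = {
    "task": "fold the towel",                            # task description
    "observation.images.base": torch.rand(3, 224, 224),  # policy.image_key frame
    "observation.state": torch.rand(14),                 # policy.state_key joints
}
# The processor turns this into `video_features`, `text_features`,
# `state_features`, plus the `sparse_targets`/`dense_targets` labels.
```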

---

## Set Up Your Environment

1. Install LeRobot by following our [Installation Guide](./installation).
2. Install SARM dependencies by running:

```bash
pip install -e ".[sarm]"
```

---

## Annotation Modes

You can choose from **3 annotation modes** that determine how progress labels are computed:

| Mode           | Annotations Required | Heads                        | Use Case                                                     |
| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ |
| `single_stage` | None                 | Sparse only                  | Simple tasks, quick experiments, no VLM needed                |
| `dense_only`   | Dense (VLM)          | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages  |
| `dual`         | Sparse + Dense (VLM) | Dual                         | Full SARM paper setup with both granularities                 |

### Mode Details

<hfoptions id="mode_explanation">
<hfoption id="single_stage">

**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration.

- **Sparse head**: 1 stage ("task"), linear progress
- **Dense head**: Not used
- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available

Workflow:

```
1. Train SARM → 2. Visualize predictions → 3. (Optional) Train policy with RA-BC
```

</hfoption>
<hfoption id="dense_only">

**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression.

- **Sparse head**: 1 stage ("task"), linear progress (auto-generated)
- **Dense head**: Multiple fine-grained stages from VLM annotations
- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages

Workflow:

```
1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
```

</hfoption>
<hfoption id="dual">

**Both sparse and dense annotations from a VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions.

- **Sparse head**: High-level stages from VLM annotations
- **Dense head**: Fine-grained stages from VLM annotations
- **Best for**: Complex multi-stage tasks where both granularities are useful

Workflow:

```
1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
```

</hfoption>
</hfoptions>

---

## Step 1: Subtask Annotation

<hfoptions id="annotation_mode">
<hfoption id="single_stage">

**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically.

</hfoption>
<hfoption id="dense_only">

Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated.

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --dense-only \
    --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
    --video-key observation.images.base \
    --num-workers 4 \
    --push-to-hub
```

**What gets saved:**

- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`)
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
- Per-episode columns in `episodes/*.parquet`:
  - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
  - (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`)

</hfoption>
<hfoption id="dual">

Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM.

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \
    --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
    --video-key observation.images.base \
    --num-workers 4 \
    --push-to-hub
```

**What gets saved:**

- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
- Per-episode columns in `episodes/*.parquet`:
  - `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames`
  - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
  - (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`)

</hfoption>
</hfoptions>
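
The proportion files are plain JSON mappings from subtask name to its average share of an episode (in `dense_only` mode the sparse file is just `{"task": 1.0}`). A quick sanity check, assuming a locally cached dataset (the path below is hypothetical; point it at your setup):

```python
import json
from pathlib import Path

# Hypothetical local dataset root; adjust to where your dataset is cached.
root = Path("~/.cache/huggingface/lerobot/your-username/your-dataset").expanduser()

props = json.loads((root / "meta" / "temporal_proportions_dense.json").read_text())
print(props)  # e.g. {"Bring robot arms up from starting position": 0.15, ...}

# The temporal proportions α̅_k are per-subtask averages and should sum to 1.
assert abs(sum(props.values()) - 1.0) < 1e-6
```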

### Annotation Arguments

| Argument               | Description                                                                      |
| ---------------------- | -------------------------------------------------------------------------------- |
| `--repo-id`            | HuggingFace dataset repository ID                                                |
| `--sparse-subtasks`    | Comma-separated list of high-level subtask names                                 |
| `--dense-subtasks`     | Comma-separated list of fine-grained subtask names                               |
| `--dense-only`         | Generate only dense annotations (auto-creates sparse "task" stage)               |
| `--video-key`          | Camera/video key to use (e.g., `observation.images.top`)                         |
| `--num-workers`        | Number of parallel GPU workers (default: 1)                                      |
| `--episodes`           | Specific episode indices to annotate (default: all)                              |
| `--skip-existing`      | Skip episodes that already have annotations                                      |
| `--model`              | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`)                            |
| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip)  |

> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step.

---

## Step 2: Verify Annotations

<hfoptions id="verify_mode">
<hfoption id="single_stage">

**No verification needed!** Skip this step.

</hfoption>
<hfoption id="dense_only">

Visualize annotations using the `--visualize-only` flag:

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --visualize-only \
    --visualize-type dense \
    --num-visualizations 5 \
    --video-key observation.images.base \
    --output-dir ./subtask_viz
```

</hfoption>
<hfoption id="dual">

Visualize annotations using the `--visualize-only` flag:

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --visualize-only \
    --visualize-type both \
    --num-visualizations 5 \
    --video-key observation.images.base \
    --output-dir ./subtask_viz
```

</hfoption>
</hfoptions>

This generates visualizations showing the video frames with subtask boundaries overlaid, plus a timeline of the subtasks.

### Visualization Arguments

| Argument               | Description                                                     |
| ---------------------- | --------------------------------------------------------------- |
| `--visualize-only`     | Only visualize existing annotations (no generation)             |
| `--num-visualizations` | Number of episodes to visualize (default: 5)                    |
| `--visualize-type`     | Type of annotations to visualize: `sparse`, `dense`, or `both`  |

**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run.

---

## Step 3: Train SARM

<hfoptions id="train_mode">
<hfoption id="single_stage">

Train with **no annotations** - uses linear progress from 0 to 1:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=sarm \
    --policy.annotation_mode=single_stage \
    --policy.image_key=observation.images.base \
    --output_dir=outputs/train/sarm_single \
    --batch_size=32 \
    --steps=5000 \
    --wandb.enable=true \
    --wandb.project=sarm \
    --policy.repo_id=your-username/your-model-name
```

</hfoption>
<hfoption id="dense_only">

Train with **dense annotations only** (sparse auto-generated):

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=sarm \
    --policy.annotation_mode=dense_only \
    --policy.image_key=observation.images.base \
    --output_dir=outputs/train/sarm_dense \
    --batch_size=32 \
    --steps=5000 \
    --wandb.enable=true \
    --wandb.project=sarm \
    --policy.repo_id=your-username/your-model-name
```

</hfoption>
<hfoption id="dual">

Train with **both sparse and dense annotations**:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=sarm \
    --policy.annotation_mode=dual \
    --policy.image_key=observation.images.base \
    --output_dir=outputs/train/sarm_dual \
    --batch_size=32 \
    --steps=5000 \
    --wandb.enable=true \
    --wandb.project=sarm \
    --policy.repo_id=your-username/your-model-name
```

</hfoption>
</hfoptions>

### Multi-GPU Training

Prefix the command with `accelerate launch --multi_gpu --num_processes=4` to train on multiple GPUs.

### Training Arguments

| Argument                   | Description                                                           | Default                  |
| -------------------------- | --------------------------------------------------------------------- | ------------------------ |
| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual`                                | `single_stage`           |
| `--policy.image_key`       | Camera key for images                                                  | `observation.images.top` |
| `--policy.state_key`       | Key for joint states                                                   | `observation.state`      |
| `--policy.n_obs_steps`     | Observation history steps (total obs frames = `n_obs_steps + 1`)       | `8`                      |
| `--policy.frame_gap`       | Gap (in frames) between sampled observations (at 30 fps, 30 ≈ 1 s)     | `30`                     |
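
To see what these defaults imply before changing them, here is a quick back-of-the-envelope check (plain Python; assumes a 30 fps dataset with the default `n_obs_steps=8` and `frame_gap=30`):

```python
n_obs_steps, frame_gap, fps = 8, 30, 30

num_obs_frames = n_obs_steps + 1                 # 9 observation frames per query
window_seconds = n_obs_steps * frame_gap / fps   # 8.0 s of temporal context
offsets = [i * frame_gap / fps for i in range(num_obs_frames)]

print(num_obs_frames, window_seconds)  # 9 8.0
print(offsets)                         # [0.0, 1.0, 2.0, ..., 8.0]
```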

---

## Step 4: Visualize Predictions

Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file.

<hfoptions id="viz_mode">
<hfoption id="single_stage">

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
    --num-visualizations 5 \
    --head-mode sparse \
    --output-dir ./sarm_viz
```

</hfoption>
<hfoption id="dense_only">

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
    --num-visualizations 5 \
    --head-mode dense \
    --output-dir ./sarm_viz
```

</hfoption>
<hfoption id="dual">

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
    --num-visualizations 5 \
    --head-mode both \
    --output-dir ./sarm_viz
```

</hfoption>
</hfoptions>

The visualization shows:

- **Progress plot**: Predicted progress (and an optional annotation-derived "GT" curve when available and `--stride 1`)
- **Stage probabilities**: Stacked area plot of predicted stage probabilities
- **Sample frames**: Key frames from the episode with progress/stage labels

### Visualization Arguments

| Argument               | Description                                                |
| ---------------------- | ---------------------------------------------------------- |
| `--visualize-only`     | Only visualize predictions (no RA-BC computation)          |
| `--num-visualizations` | Number of episodes to visualize (default: 5)               |
| `--head-mode`          | SARM head to use: `sparse`, `dense`, or `both`             |
| `--stride`             | Compute every N frames, interpolate the rest (default: 1)  |

---

## Step 5 (Optional): Train Policy with RA-BC

Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples by their predicted progress improvement. This requires two steps:

1. **Precompute progress values** for all frames using the trained SARM model
2. **Train the policy** with RA-BC weighting using the precomputed values

### How RA-BC Works

For each training sample, RA-BC computes the progress delta:

```
r_i = φ(o_{t+Δ}) - φ(o_t)
```

Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) receive higher weights, while samples with negative or zero progress are down-weighted.

The weighting follows **Equations 8-9** from the paper:

- **Soft weight**: `w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)`
- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i`
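
For reference, here is a minimal NumPy sketch of Equations 8-9, assuming per-frame progress values have already been computed (Step 5a below); `rabc_weights` is an illustrative name, not part of the LeRobot API:

```python
import numpy as np

def rabc_weights(progress: np.ndarray, delta: int, kappa: float = 0.01,
                 eps: float = 1e-6) -> np.ndarray:
    """progress: per-frame SARM progress φ; delta: the policy's chunk_size Δ."""
    # r_i = φ(o_{t+Δ}) − φ(o_t), clamped at the end of the array
    future = np.minimum(np.arange(len(progress)) + delta, len(progress) - 1)
    r = progress[future] - progress

    mu, sigma = r.mean(), r.std()
    soft = np.clip((r - (mu - 2 * sigma)) / (4 * sigma + eps), 0.0, 1.0)  # Eq. 8

    # Eq. 9: full weight above κ, soft weight on [0, κ], zero for negative deltas
    return np.where(r > kappa, 1.0, np.where(r >= 0.0, soft, 0.0))
```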

### Step 5a: Compute SARM Progress Values

First, run the SARM model on all frames in your dataset to compute progress values:

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --head-mode sparse \
    --num-visualizations 5 \
    --push-to-hub
```

This script:

- Processes all frames and computes progress values
- Saves progress values to a parquet file next to the dataset on disk (defaults to `<dataset_root>/sarm_progress.parquet`)
- Generates visualizations of the first N episodes (default: 5)

**Arguments:**

| Argument               | Description                                                     | Default    |
| ---------------------- | --------------------------------------------------------------- | ---------- |
| `--reward-model-path`  | Path to trained SARM model                                      | (required) |
| `--head-mode`          | SARM head to use: `sparse`, `dense`, or `both`                  | `sparse`   |
| `--device`             | Device for inference                                            | `cuda`     |
| `--visualize-only`     | Only visualize predictions (no RA-BC computation)               | `false`    |
| `--num-visualizations` | Number of episodes to visualize (default: 5, set to 0 to skip)  | `5`        |

**Output format** (`sarm_progress.parquet`):

| Column            | Description                                    |
| ----------------- | ---------------------------------------------- |
| `index`           | Global frame index in dataset                  |
| `episode_index`   | Episode number                                 |
| `frame_index`     | Local frame index within episode               |
| `progress_sparse` | Sparse head progress value [0, 1]              |
| `progress_dense`  | Dense head progress value [0, 1] (if computed) |
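
A quick way to sanity-check the output before training (the path below is an assumption; point it at your dataset root):

```python
import pandas as pd

df = pd.read_parquet("path/to/your-dataset/sarm_progress.parquet")
print(df.columns.tolist())
# e.g. ['index', 'episode_index', 'frame_index', 'progress_sparse', ...]

# Within an episode, progress should start near 0 and end near 1.
ep0 = df[df["episode_index"] == 0].sort_values("frame_index")
print(ep0["progress_sparse"].iloc[0], ep0["progress_sparse"].iloc[-1])
```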

### Step 5b: Train Policy with RA-BC

Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). RA-BC is currently supported for PI0, PI0.5, and SmolVLA:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=pi0 \
    --use_rabc=true \
    --rabc_head_mode=sparse \
    --rabc_kappa=0.01 \
    --output_dir=outputs/train/policy_rabc \
    --batch_size=32 \
    --steps=40000
```

The training script automatically:

- Loads the precomputed progress values from the parquet file
- Uses the policy's `chunk_size` to compute progress deltas (Δ)
- Computes sample weights based on progress improvement
- Applies the weighted loss during training

**RA-BC Arguments:**

| Argument               | Description                                                 | Default                            |
| ---------------------- | ----------------------------------------------------------- | ---------------------------------- |
| `--use_rabc`           | Enable RA-BC sample weighting                               | `false`                            |
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset)  | `sarm_progress.parquet` in dataset |
| `--rabc_head_mode`     | Which SARM head's progress to use: `sparse` or `dense`      | `sparse`                           |
| `--rabc_kappa`         | Threshold κ for high-quality samples                        | `0.01`                             |

### Tuning RA-BC Kappa

The `kappa` parameter is the threshold that determines which samples get full weight (w = 1). Tuning it well is critical for RA-BC to work effectively.

**How the weighting works:**

| Condition           | Weight                  |
| ------------------- | ----------------------- |
| `delta > kappa`     | 1.0 (hard threshold)    |
| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 8  |
| `delta < 0`         | 0.0 (negative progress) |

**Diagnosing kappa issues:**

Monitor these WandB metrics during training:

| Metric             | Healthy Range | Problem Indicator         |
| ------------------ | ------------- | ------------------------- |
| `rabc_mean_weight` | 0.3 - 0.8     | ≈ 1.0 means kappa too low |
| `rabc_delta_mean`  | > 0           | Should be positive        |
| `rabc_delta_std`   | > 0           | Variance in data quality  |

**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft weighting entirely, so RA-BC becomes equivalent to vanilla BC.

**Setting kappa based on your data:**

The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:

```
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
# Most deltas fall in range [0.01, 0.05]

# Option 1: Set kappa = delta_mean (medium selectivity)
--rabc_kappa=0.03

# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
--rabc_kappa=0.05

# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
--rabc_kappa=0.07
```
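
If you prefer to derive κ from the logged statistics programmatically, the arithmetic above reduces to a one-liner (an illustrative helper, not a LeRobot function):

```python
def suggest_kappa(delta_mean: float, delta_std: float, n_std: int = 1) -> float:
    """n_std: 0 = medium selectivity, 1 = high, 2 = very selective."""
    return delta_mean + n_std * delta_std

print(suggest_kappa(0.03, 0.02, n_std=1))  # 0.05, matching Option 2 above
```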

**When RA-BC may not help:**

If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit, since there is little to filter.

### Multi-GPU Training with RA-BC

```bash
accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=pi0 \
    --use_rabc=true \
    --rabc_kappa=0.01 \
    --output_dir=outputs/train/policy_rabc \
    --batch_size=32 \
    --steps=40000
```

---

## Tips & Best Practices

### Choosing a Mode

- **Start with `single_stage`** for quick experiments - no annotation overhead
- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages
- Use **`dual`** for complex tasks where both coarse and fine-grained progress are meaningful

### Annotation Quality

1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center"
2. **Verify with visualization**: Always check a few episodes before training
3. **Consistent naming**: Use the same subtask names across all episodes

### RA-BC

1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))

---

## Citation

```bibtex
@article{chen2025sarm,
  title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
  author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
  journal={arXiv preprint arXiv:2509.25358},
  year={2025}
}
```
@@ -96,7 +96,7 @@ dependencies = [
 # Common
 pygame-dep = ["pygame>=2.5.1,<2.7.0"]
 placo-dep = ["placo>=0.9.6,<0.10.0"]
-transformers-dep = ["transformers>=4.53.0,<5.0.0"]
+transformers-dep = ["transformers>=4.57.1,<5.0.0"]
 grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]  # TODO: Bumb dependency (compatible with wandb)

 # Motors
@@ -133,6 +133,7 @@ groot = [
     "ninja>=1.11.1,<2.0.0",
     "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
 ]
+sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
 xvla = ["lerobot[transformers-dep]"]
 hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -173,6 +174,7 @@ all = [
     "lerobot[phone]",
     "lerobot[libero]",
     "lerobot[metaworld]",
+    "lerobot[sarm]"
 ]

 [project.scripts]
@@ -65,9 +65,17 @@ class TrainPipelineConfig(HubMixin):
     scheduler: LRSchedulerConfig | None = None
     eval: EvalConfig = field(default_factory=EvalConfig)
     wandb: WandBConfig = field(default_factory=WandBConfig)
-    checkpoint_path: Path | None = field(init=False, default=None)
+
+    # RA-BC (Reward-Aligned Behavior Cloning) parameters
+    use_rabc: bool = False  # Enable reward-weighted training
+    rabc_progress_path: str | None = None  # Path to precomputed SARM progress parquet file
+    rabc_kappa: float = 0.01  # Hard threshold for high-quality samples
+    rabc_epsilon: float = 1e-6  # Small constant for numerical stability
+    rabc_head_mode: str | None = "sparse"  # For dual-head models: "sparse" or "dense"

     # Rename map for the observation to override the image and state keys
     rename_map: dict[str, str] = field(default_factory=dict)
+    checkpoint_path: Path | None = field(init=False, default=None)

     def validate(self) -> None:
         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
@@ -131,6 +139,14 @@ class TrainPipelineConfig(HubMixin):
                 "'policy.repo_id' argument missing. Please specify it to push the model to the hub."
             )

+        if self.use_rabc and not self.rabc_progress_path:
+            # Auto-detect from dataset path
+            repo_id = self.dataset.repo_id
+            if self.dataset.root:
+                self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
+            else:
+                self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
+
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
@@ -0,0 +1,13 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
File diff suppressed because it is too large
@@ -29,6 +29,7 @@ __all__ = [
     "PI0Config",
     "PI05Config",
     "SmolVLAConfig",
+    "SARMConfig",
     "TDMPCConfig",
     "VQBeTConfig",
     "GrootConfig",
@@ -50,6 +50,7 @@ class ACTPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: ACTConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -56,6 +56,7 @@ class DiffusionPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: DiffusionConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -37,6 +37,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
 from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.utils import validate_visual_features_consistency
@@ -105,6 +106,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

         return SmolVLAPolicy
+    elif name == "sarm":
+        from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
+
+        return SARMRewardModel
     elif name == "groot":
         from lerobot.policies.groot.modeling_groot import GrootPolicy

@@ -337,6 +342,14 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )

+    elif isinstance(policy_cfg, SARMConfig):
+        from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
+
+        processors = make_sarm_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            dataset_meta=kwargs.get("dataset_meta"),
+        )
     elif isinstance(policy_cfg, GrootConfig):
         from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors

@@ -435,6 +448,13 @@ def make_policy(
         cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
     kwargs["config"] = cfg

+    # Pass dataset_stats to the policy if available (needed for some policies like SARM)
+    if ds_meta is not None and hasattr(ds_meta, "stats"):
+        kwargs["dataset_stats"] = ds_meta.stats
+
+    if ds_meta is not None:
+        kwargs["dataset_meta"] = ds_meta
+
     if cfg.pretrained_path:
         # Load a pretrained policy and override the config if needed (for example, if there are inference-time
         # hyperparameters that we want to vary).
@@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy):
     name = "groot"
     config_class = GrootConfig

-    def __init__(self, config: GrootConfig):
+    def __init__(self, config: GrootConfig, **kwargs):
         """Initialize Groot policy wrapper."""
         super().__init__(config)
         config.validate_features()
@@ -907,6 +907,7 @@ class PI0Policy(PreTrainedPolicy):
     def __init__(
         self,
         config: PI0Config,
+        **kwargs,
     ):
         """
         Args:
@@ -1235,9 +1236,15 @@ class PI0Policy(PreTrainedPolicy):

         return actions

-    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
-        """Run the batch through the model and compute the loss for training."""
+    def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
+        """Run the batch through the model and compute the loss for training.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
         # Prepare inputs
         images, img_masks = self._preprocess_images(batch)
         lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
@@ -1251,11 +1258,17 @@ class PI0Policy(PreTrainedPolicy):
         original_action_dim = self.config.output_features[ACTION].shape[0]
         losses = losses[:, :, :original_action_dim]

-        loss = losses.mean()
-
         loss_dict = {
-            "loss": loss.item(),
             "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
         }

-        return loss, loss_dict
+        if reduction == "none":
+            # Return per-sample losses (B,) by averaging over time and action dims
+            per_sample_loss = losses.mean(dim=(1, 2))
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = losses.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict
@@ -880,6 +880,7 @@ class PI05Policy(PreTrainedPolicy):
     def __init__(
         self,
         config: PI05Config,
+        **kwargs,
     ):
         """
         Args:
@@ -1209,9 +1210,15 @@ class PI05Policy(PreTrainedPolicy):

         return actions

-    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
-        """Run the batch through the model and compute the loss for training."""
+    def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
+        """Run the batch through the model and compute the loss for training.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
         # Prepare inputs
         images, img_masks = self._preprocess_images(batch)
         tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
@@ -1225,11 +1232,17 @@ class PI05Policy(PreTrainedPolicy):
         original_action_dim = self.config.output_features[ACTION].shape[0]
         losses = losses[:, :, :original_action_dim]

-        loss = losses.mean()
-
         loss_dict = {
-            "loss": loss.item(),
             "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
         }

-        return loss, loss_dict
+        if reduction == "none":
+            # Return per-sample losses (B,) by averaging over time and action dims
+            per_sample_loss = losses.mean(dim=(1, 2))
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = losses.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict
@@ -0,0 +1,870 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Compute SARM progress values for RA-BC (Reward-Aware Behavior Cloning) weighting.

This script processes all frames in a dataset with SARM to compute progress values [0, 1].
The results are saved as a parquet file that can be loaded during training for RA-BC weighting.

Uses multi-output extraction: each SARM query returns progress for 9 frames, so we only
need ~num_frames/30 queries instead of one per frame (~30x speedup).

Usage:
    # Full RA-BC computation with visualizations
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4

    # Faster computation with stride (compute every 5 frames, interpolate the rest)
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4 \\
        --stride 5

    # Visualize predictions only (no RA-BC computation)
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4 \\
        --visualize-only \\
        --num-visualizations 5

The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'.
"""
import argparse
import logging
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
from lerobot.policies.sarm.sarm_utils import normalize_stage_tau


def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
    """Read reward_model_path from parquet metadata if available."""
    if not parquet_path.exists():
        return None
    try:
        metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
        if metadata and b"reward_model_path" in metadata:
            return metadata[b"reward_model_path"].decode()
    except Exception:  # nosec B110
        return None
    return None


def load_sarm_resources(
    dataset_repo_id: str,
    reward_model_path: str,
    device: str = "cuda",
) -> tuple[LeRobotDataset, SARMRewardModel, any]:
    """
    Load SARM model, dataset, and preprocessor.

    Returns:
        Tuple of (dataset, reward_model, preprocessor)
    """
    logging.info(f"Loading model: {reward_model_path}")
    reward_model = SARMRewardModel.from_pretrained(reward_model_path)
    reward_model.config.device = device
    reward_model.to(device).eval()

    image_key = reward_model.config.image_key
    state_key = reward_model.config.state_key
    delta_indices = reward_model.config.observation_delta_indices

    logging.info(f"Loading dataset: {dataset_repo_id}")
    temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
    fps = temp_dataset.fps

    delta_timestamps = {
        image_key: [idx / fps for idx in delta_indices],
        state_key: [idx / fps for idx in delta_indices],
    }
    dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps)
    logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")

    preprocess, _ = make_sarm_pre_post_processors(
        config=reward_model.config,
        dataset_stats=dataset.meta.stats,
        dataset_meta=dataset.meta,
    )

    return dataset, reward_model, preprocess


def to_numpy_image(img) -> np.ndarray:
    """Convert image tensor to numpy uint8 (H, W, C)."""
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy()
    if img.ndim == 4:
        # Take center frame for bidirectional sampling
        img = img[img.shape[0] // 2]
    if img.shape[0] in [1, 3]:
        img = np.transpose(img, (1, 2, 0))
    if img.dtype != np.uint8:
        # Handle normalized images (may have negative values or values > 1)
        img = img.astype(np.float32)
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # Normalize to [0, 1]
        img = (img * 255).astype(np.uint8)
    return img


def visualize_episode(
    frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None
):
    """Create visualization with progress plot, stage probabilities, and sample frames.

    Same as sarm_inference_visualization.py
    """
    num_stages = stage_preds.shape[1]
    colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
    frame_indices = np.arange(len(progress_preds))

    fig = plt.figure(figsize=(14, 12))
    gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
    ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2])

    # Progress plot
    ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted")
    ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB")
    if gt_progress is not None:
        ax_progress.plot(
            frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth"
        )
    ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)
    ax_progress.set_ylabel("Progress")
    ax_progress.set_title(f'Task: "{title}"', fontweight="bold")
    ax_progress.set_ylim(-0.05, 1.1)
    ax_progress.legend(loc="upper left")
    ax_progress.grid(True, alpha=0.3)

    # Stage predictions
    ax_stages.stackplot(
        frame_indices,
        *[stage_preds[:, i] for i in range(num_stages)],
        colors=colors,
        alpha=0.8,
        labels=stage_labels,
    )
    if gt_stages is not None:
        for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1:
            ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5)
    ax_stages.set_xlabel("Frame")
    ax_stages.set_ylabel("Stage Probability")
    ax_stages.set_ylim(0, 1)
    ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8)
    ax_stages.grid(True, alpha=0.3)

    # Sample frames
    ax_frames.axis("off")
    num_sample = 8
    sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int)
    h, w = frames[0].shape[:2]
    combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8)
    for i, idx in enumerate(sample_indices):
        frame = frames[idx]
        if frame.shape[-1] == 1:
            frame = np.repeat(frame, 3, axis=-1)
        combined[:, i * w : (i + 1) * w] = frame
        stage_name = stage_labels[np.argmax(stage_preds[idx])][:12]
        ax_frames.text(
            i * w + w / 2,
            -10,
            f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}",
            ha="center",
            va="top",
            fontsize=7,
        )
    ax_frames.imshow(combined)
    ax_frames.set_title("Sample Frames", pad=20)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {output_path}")


def visualize_sarm_predictions(
    dataset: LeRobotDataset,
    reward_model: SARMRewardModel,
    preprocess,
    episode_indices: list[int],
    head_mode: str,
    output_dir: Path,
    num_display_frames: int = 5,
    stride: int = 1,
):
    """
    Visualize SARM predictions for multiple episodes.

    Computes predictions for every frame by default. With stride > 1, computes predictions
    every N frames and interpolates (progress + stage probabilities) for visualization.

    Args:
        dataset: LeRobotDataset with delta_timestamps configured
        reward_model: Loaded SARM model
        preprocess: Preprocessor from make_sarm_pre_post_processors
        episode_indices: List of episode indices to visualize
        head_mode: "sparse", "dense", or "both"
        output_dir: Directory to save visualizations
        num_display_frames: Number of frames to display in thumbnail strip (default: 5)
        stride: Compute predictions every N frames, interpolate the rest (default: 1)
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    image_key = reward_model.config.image_key
    state_key = reward_model.config.state_key
    dual_mode = reward_model.config.uses_dual_heads
    device = reward_model.device

    # Center frame index for bidirectional sampling
    target_idx = reward_model.config.n_obs_steps // 2

    # Determine which heads to visualize
    schemes_to_viz = []
    if head_mode in ("sparse", "both") or not dual_mode:
        schemes_to_viz.append("sparse")
    if head_mode in ("dense", "both") and dual_mode:
        schemes_to_viz.append("dense")

    # Set preprocessor to eval mode to disable augmentations
    if hasattr(preprocess, "eval"):
        preprocess.eval()
    for step in preprocess.steps:
        if hasattr(step, "eval"):
            step.eval()

    for episode_idx in episode_indices:
        ep = dataset.meta.episodes[episode_idx]
        ep_start = ep["dataset_from_index"]
        ep_end = ep["dataset_to_index"]
        task = dataset[ep_start].get("task", "perform the task")
        num_frames = ep_end - ep_start

        # Select frames for display thumbnails (evenly sampled from begin to end)
        display_indices = set(
            [
                ep_start + int(i * (num_frames - 1) / (num_display_frames - 1))
                for i in range(num_display_frames)
            ]
            if num_frames >= num_display_frames
            else list(range(ep_start, ep_end))
        )
        viz_frames = {}

        # Load display frames up-front (stride mode might skip them otherwise).
        for frame_idx in display_indices:
            sample = dataset[frame_idx]
            viz_frames[frame_idx] = to_numpy_image(sample[image_key])

        # Initialize storage for each scheme
        scheme_data = {}
        for scheme in schemes_to_viz:
            num_stages = getattr(reward_model.config, f"num_{scheme}_stages")
            scheme_data[scheme] = {
                "viz_progress": np.full(num_frames, np.nan),
                "viz_stages": np.full((num_frames, num_stages), np.nan),
                "viz_gt_progress": np.full(num_frames, np.nan),
                "viz_gt_stages": np.full(num_frames, np.nan),
                "target_key": f"{scheme}_targets",
                "num_stages": num_stages,
                "temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"),
                "subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"),
            }

        if stride > 1:
            logging.info(f"Visualization stride={stride}: inferring every {stride} frames and interpolating")

        # Process frames one at a time to avoid memory buildup
        frame_indices = list(range(ep_start, ep_end, stride))
        if (ep_end - 1) not in frame_indices:
            frame_indices.append(ep_end - 1)
        frame_indices = sorted(set(frame_indices))

        for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", leave=False):
            local_idx = frame_idx - ep_start
            sample = dataset[frame_idx]

            batch = {
                image_key: sample[image_key],
                "task": task,
                "index": frame_idx,
                "episode_index": episode_idx,
            }
            if state_key in sample:
                batch[state_key] = sample[state_key]

            with torch.no_grad():
                processed = preprocess(batch)
                video_features = processed["video_features"].to(device)
                text_features = processed["text_features"].to(device)
                state_features = processed.get("state_features")
                if state_features is not None:
                    state_features = state_features.to(device)
                lengths = processed.get("lengths")

                for scheme in schemes_to_viz:
                    sd = scheme_data[scheme]

                    # Ground truth
                    # In stride visualization mode, ground-truth plots can be misleading
                    # (only sparse points are available), so we skip GT.
                    if stride == 1 and sd["target_key"] in processed:
                        gt_target = processed[sd["target_key"]][0, target_idx].cpu().item()
                        sd["viz_gt_stages"][local_idx] = int(gt_target)
                        sd["viz_gt_progress"][local_idx] = normalize_stage_tau(
                            gt_target,
                            num_stages=sd["num_stages"],
                            temporal_proportions=sd["temporal_props"],
                            subtask_names=sd["subtask_names"],
                        )

                    # Predictions
                    reward, stage_probs = reward_model.calculate_rewards(
                        text_embeddings=text_features,
                        video_embeddings=video_features,
                        state_features=state_features,
                        lengths=lengths,
                        return_all_frames=True,
                        return_stages=True,
                        head_mode=scheme,
                    )

                    # Handle both tensor and numpy outputs
                    if isinstance(reward, torch.Tensor):
                        reward = reward.cpu().numpy()
                        stage_probs = stage_probs.cpu().numpy()

                    if reward.ndim == 2:
                        sd["viz_progress"][local_idx] = reward[0, target_idx]
                        sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :]
                    else:
                        sd["viz_progress"][local_idx] = reward[target_idx]
                        sd["viz_stages"][local_idx] = stage_probs[target_idx, :]

            # Clear GPU memory after each frame
            del processed, video_features, text_features
            if state_features is not None:
                del state_features

        torch.cuda.empty_cache()

        # Interpolate predictions back to per-frame arrays for smooth visualization.
        if stride > 1:
            all_local = np.arange(num_frames)
            for scheme in schemes_to_viz:
                sd = scheme_data[scheme]

                valid = np.isfinite(sd["viz_progress"])
                valid_idx = np.where(valid)[0]
                if valid_idx.size >= 1:
                    sd["viz_progress"] = interpolate_progress(
                        valid_idx, sd["viz_progress"][valid_idx], all_local
                    )

                    stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32)
                    for s in range(sd["num_stages"]):
                        stage_interp[:, s] = interpolate_progress(
                            valid_idx, sd["viz_stages"][valid_idx, s], all_local
                        )

                    stage_interp = np.clip(stage_interp, 0.0, 1.0)
                    row_sums = stage_interp.sum(axis=1, keepdims=True)
                    nz = row_sums.squeeze(-1) > 0
                    stage_interp[nz] = stage_interp[nz] / row_sums[nz]
                    sd["viz_stages"] = stage_interp
                else:
                    # No valid points: keep NaNs/zeros; visualization will be empty.
                    sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0)

        # Generate visualization for each head
        ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)]
        for scheme in schemes_to_viz:
            sd = scheme_data[scheme]
            stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])]
            viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png"

            visualize_episode(
                frames=np.array(ordered_viz_frames),
                progress_preds=sd["viz_progress"],
                stage_preds=sd["viz_stages"],
                title=f"{task} (Episode {episode_idx})",
                output_path=viz_path,
                stage_labels=stage_labels,
                gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None,
                gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None,
            )

        # Clear memory between episodes
        torch.cuda.empty_cache()

    logging.info(f"Visualizations saved to: {output_dir.absolute()}")
def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
    """Generate all frame indices, ordered by offset for cache-friendly access.

    Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...]
    This groups frames that share similar temporal windows together.
    """
    num_frames = ep_end - ep_start
    indices = []
    for offset in range(frame_gap):
        for frame_rel in range(offset, num_frames, frame_gap):
            indices.append(ep_start + frame_rel)
    return indices
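
# A minimal doctest-style sketch of the ordering (toy numbers, not from the
# source): with frame_gap=2, each offset group is emitted before the next.
#     >>> generate_all_frame_indices(ep_start=10, ep_end=15, frame_gap=2)
#     [10, 12, 14, 11, 13]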


def interpolate_progress(
    computed_indices: np.ndarray,
    computed_values: np.ndarray,
    all_indices: np.ndarray,
) -> np.ndarray:
    """Linearly interpolate values to fill in gaps (robust to NaNs / edge cases)."""
    computed_indices = np.asarray(computed_indices)
    computed_values = np.asarray(computed_values)
    all_indices = np.asarray(all_indices)

    mask = np.isfinite(computed_values)
    if mask.sum() == 0:
        return np.full(all_indices.shape, np.nan, dtype=np.float32)
    if mask.sum() == 1:
        return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32)

    out = np.interp(all_indices, computed_indices[mask], computed_values[mask])
    return out.astype(np.float32)
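
# Sketch of the stride-filling behavior this helper provides (toy values):
# predictions computed at frames 0 and 4 are linearly spread over frames 1-3.
#     >>> interpolate_progress(np.array([0, 4]), np.array([0.0, 1.0]), np.arange(5))
#     array([0.  , 0.25, 0.5 , 0.75, 1.  ], dtype=float32)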


def compute_sarm_progress(
    dataset_repo_id: str,
    reward_model_path: str,
    output_path: str | None = None,
    head_mode: str = "sparse",
    device: str = "cuda",
    num_visualizations: int = 5,
    output_dir: str = "./sarm_viz",
    stride: int = 1,
):
    """
    Compute SARM progress predictions for all frames in a dataset.

    Args:
        dataset_repo_id: HuggingFace dataset repo ID or local path
        reward_model_path: Path to pretrained SARM model
        output_path: Path to save results. If None, saves to dataset's cache directory
        head_mode: SARM head to use ("sparse", "dense", or "both")
        device: Device to use for inference
        num_visualizations: Number of episodes to visualize (0 to skip)
        output_dir: Directory to save visualizations
        stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame)
    """
    dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device)

    # Set preprocessor to eval mode to disable augmentations
    if hasattr(preprocess, "eval"):
        preprocess.eval()
    for step in preprocess.steps:
        if hasattr(step, "eval"):
            step.eval()

    image_key = reward_model.config.image_key
    state_key = reward_model.config.state_key
    frame_gap = reward_model.config.frame_gap
    num_episodes = dataset.num_episodes
    total_frames = dataset.num_frames
    logging.info(f"Processing {total_frames} frames across {num_episodes} episodes")

    # Determine which heads to compute
    dual_mode = reward_model.config.uses_dual_heads
    compute_sparse = head_mode in ("sparse", "both") or not dual_mode
    compute_dense = head_mode in ("dense", "both") and dual_mode

    # Storage arrays
    all_indices = []
    all_episode_indices = []
    all_frame_indices = []
    all_progress_sparse = [] if compute_sparse else None
    all_progress_dense = [] if compute_dense else None

    if stride > 1:
        logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest")

    # Process all episodes
    for episode_idx in tqdm(range(num_episodes), desc="Episodes"):
        ep = dataset.meta.episodes[episode_idx]
        ep_start = ep["dataset_from_index"]
        ep_end = ep["dataset_to_index"]

        # Get task description
        task = dataset[ep_start].get("task", "perform the task")

        # Generate frames to compute (with stride applied)
        all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
        if stride > 1:
            # Only compute every stride-th frame (relative to episode start)
            compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0]
            # Always include last frame for better interpolation at episode end
            last_frame = ep_end - 1
            if last_frame not in compute_indices:
                compute_indices.append(last_frame)
            compute_indices = sorted(set(compute_indices))
        else:
            compute_indices = all_ep_indices

        center_idx = reward_model.config.n_obs_steps // 2  # Center of bidirectional window

        # Dictionary to collect results
        frame_results = {}

        for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False):
            try:
                sample = dataset[query_idx]

                batch = {
                    image_key: sample[image_key],
                    "task": task,
                    "index": query_idx,
                    "episode_index": episode_idx,
                }
                if state_key in sample:
                    batch[state_key] = sample[state_key]

                with torch.no_grad():
                    processed = preprocess(batch)
                    video_features = processed["video_features"].to(device)
                    text_features = processed["text_features"].to(device)
                    state_features = processed.get("state_features")
                    if state_features is not None:
                        state_features = state_features.to(device)
                    lengths = processed.get("lengths")

                    sparse_val = np.nan
                    dense_val = np.nan

                    # Compute sparse prediction for center frame
                    if compute_sparse:
                        sparse_progress = reward_model.calculate_rewards(
                            text_embeddings=text_features,
                            video_embeddings=video_features,
                            state_features=state_features,
                            lengths=lengths,
                            return_all_frames=True,
                            head_mode="sparse",
                        )
                        sparse_val = float(
                            sparse_progress[0, center_idx]
                            if sparse_progress.ndim == 2
                            else sparse_progress[center_idx]
                        )

                    # Compute dense prediction for center frame
                    if compute_dense:
                        dense_progress = reward_model.calculate_rewards(
                            text_embeddings=text_features,
                            video_embeddings=video_features,
                            state_features=state_features,
                            lengths=lengths,
                            return_all_frames=True,
                            head_mode="dense",
                        )
                        dense_val = float(
                            dense_progress[0, center_idx]
                            if dense_progress.ndim == 2
                            else dense_progress[center_idx]
                        )

                frame_results[query_idx] = (sparse_val, dense_val)

            except Exception as e:
                logging.warning(f"Failed to process frame {query_idx}: {e}")

        # Interpolate to get values for all frames
        computed_indices = np.array(sorted(frame_results.keys()))
        computed_sparse = (
            np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None
        )
        computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None

        # All frame indices for this episode
        all_frame_idx_array = np.arange(ep_start, ep_end)

        if stride > 1 and len(computed_indices) > 1:
            # Interpolate progress values
            if compute_sparse:
                interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array)
            if compute_dense:
                interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array)
        else:
            # No interpolation needed
            interp_sparse = computed_sparse if compute_sparse else None
            interp_dense = computed_dense if compute_dense else None

        # Store results for all frames
        for i, frame_idx in enumerate(all_frame_idx_array):
            local_idx = frame_idx - ep_start
            all_indices.append(frame_idx)
            all_episode_indices.append(episode_idx)
            all_frame_indices.append(local_idx)
            if compute_sparse:
                if stride > 1 and len(computed_indices) > 1:
                    all_progress_sparse.append(float(interp_sparse[i]))
                elif frame_idx in frame_results:
                    all_progress_sparse.append(frame_results[frame_idx][0])
                else:
                    all_progress_sparse.append(np.nan)
            if compute_dense:
                if stride > 1 and len(computed_indices) > 1:
                    all_progress_dense.append(float(interp_dense[i]))
                elif frame_idx in frame_results:
                    all_progress_dense.append(frame_results[frame_idx][1])
                else:
                    all_progress_dense.append(np.nan)

    # Create output table
    table_data = {
        "index": np.array(all_indices, dtype=np.int64),
        "episode_index": np.array(all_episode_indices, dtype=np.int64),
        "frame_index": np.array(all_frame_indices, dtype=np.int64),
    }
    if compute_sparse:
        table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32)
    if compute_dense:
        table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32)

    # Sort by index
    df = pa.table(table_data).to_pandas()
    df = df.sort_values("index").reset_index(drop=True)
    final_table = pa.Table.from_pandas(df, preserve_index=False)

    # Add metadata with reward model path
    metadata = {b"reward_model_path": reward_model_path.encode()}
    final_table = final_table.replace_schema_metadata(metadata)

    # Determine output path
    output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path)

    # Save
    output_path.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(final_table, output_path)
    logging.info(f"Saved {len(final_table)} frame progress values to {output_path}")

    # Print statistics
    if "progress_sparse" in df.columns:
        valid = df["progress_sparse"].dropna()
        logging.info(
            f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
            f"min={valid.min():.4f}, max={valid.max():.4f}"
        )

    if "progress_dense" in df.columns:
        valid = df["progress_dense"].dropna()
        logging.info(
            f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
            f"min={valid.min():.4f}, max={valid.max():.4f}"
        )

    # Visualize episodes after processing
    if num_visualizations > 0:
        viz_episodes = list(range(min(num_visualizations, num_episodes)))
        logging.info(f"Generating {len(viz_episodes)} visualizations...")
        visualize_sarm_predictions(
            dataset=dataset,
            reward_model=reward_model,
            preprocess=preprocess,
            episode_indices=viz_episodes,
            head_mode=head_mode,
            output_dir=Path(output_dir),
            stride=stride,
        )

    return output_path
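
# Hedged sketch of consuming the saved table (column names follow table_data
# above; the file name assumes the default output_path):
#     >>> import pandas as pd
#     >>> df = pd.read_parquet("sarm_progress.parquet")
#     >>> list(df.columns)  # with head_mode="sparse"
#     ['index', 'episode_index', 'frame_index', 'progress_sparse']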


def main():
    parser = argparse.ArgumentParser(
        description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Full RA-BC computation with visualizations
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4

    # Visualize predictions only (no RA-BC computation)
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4 \\
        --visualize-only \\
        --num-visualizations 10
""",
    )
    parser.add_argument(
        "--dataset-repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repo ID or local path",
    )
    parser.add_argument(
        "--reward-model-path",
        type=str,
        default=None,
        help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default=None,
        help="Output path for parquet. If not set, saves to dataset's cache directory",
    )
    parser.add_argument(
        "--head-mode",
        type=str,
        default="sparse",
        choices=["sparse", "dense", "both"],
        help="SARM head to use (default: sparse)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use (default: cuda)",
    )
    # Visualization options
    parser.add_argument(
        "--visualize-only",
        action="store_true",
        help="Only visualize SARM predictions (no RA-BC computation)",
    )
    parser.add_argument(
        "--num-visualizations",
        type=int,
        default=5,
        help="Number of episodes to visualize (default: 5, set to 0 to skip)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./sarm_viz",
        help="Output directory for visualizations (default: ./sarm_viz)",
    )
    parser.add_argument(
        "--push-to-hub",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Upload progress file to the dataset repo on HuggingFace Hub (use --no-push-to-hub to disable)",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Compute progress every N frames, interpolate the rest (default: 1 = every frame)",
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    # Try to get reward_model_path from parquet metadata if not provided
    reward_model_path = args.reward_model_path
    if reward_model_path is None:
        # Load dataset to find parquet path
        temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
        parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet"
        reward_model_path = get_reward_model_path_from_parquet(parquet_path)
        if reward_model_path:
            logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
        else:
            raise ValueError(
                "--reward-model-path is required (no existing parquet with model metadata found)"
            )

    # Handle visualize-only mode
    if args.visualize_only:
        dataset, reward_model, preprocess = load_sarm_resources(
            args.dataset_repo_id, reward_model_path, args.device
        )
        logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes")
        viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes)))
        visualize_sarm_predictions(
            dataset=dataset,
            reward_model=reward_model,
            preprocess=preprocess,
            episode_indices=viz_episodes,
            head_mode=args.head_mode,
            output_dir=Path(args.output_dir),
            stride=args.stride,
        )
        print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}")
        return

    # Full RA-BC computation (compute_sarm_progress loads model/dataset itself)
    output_path = compute_sarm_progress(
        dataset_repo_id=args.dataset_repo_id,
        reward_model_path=reward_model_path,
        output_path=args.output_path,
        head_mode=args.head_mode,
        device=args.device,
        num_visualizations=args.num_visualizations,
        output_dir=args.output_dir,
        stride=args.stride,
    )

    print(f"\nSARM progress values saved to: {output_path}")

    # Upload to Hub if requested
    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        hub_path = "sarm_progress.parquet"

        print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
        api.upload_file(
            path_or_fileobj=str(output_path),
            path_in_repo=hub_path,
            repo_id=args.dataset_repo_id,
            repo_type="dataset",
        )
        print(
            f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
        )

        print("\nTo use in training, add to your config:")
        print("  use_rabc: true")
        print(f"  rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
        print("  rabc_head_mode: sparse  # or dense")
    else:
        print("\nTo use in training, add to your config:")
        print("  use_rabc: true")
        print(f"  rabc_progress_path: {output_path}")
        print("  rabc_head_mode: sparse  # or dense")


if __name__ == "__main__":
    main()
@@ -0,0 +1,248 @@
#!/usr/bin/env python

# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
Paper: https://arxiv.org/abs/2509.25358
"""

from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig


@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
    """Configuration class for SARM (Stage-Aware Reward Modeling).

    Supports three annotation modes:

    1. single_stage (default): No annotations needed. Uses the episode's task description
       as a single stage covering the entire episode.

    2. dense_only: Uses dense (fine-grained) annotations from a VLM, with an auto-generated
       single sparse "task" stage covering the full episode. The dense head learns detailed
       subtask progression while the sparse head provides overall task completion.

    3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained)
       annotations from a VLM. Both heads are trained on their respective annotations.

    The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions
    are loaded/generated during model initialization.
    """

    annotation_mode: str = "single_stage"  # "single_stage", "dense_only", or "dual"
    n_obs_steps: int = 8  # Number of observation history steps
    frame_gap: int = 30  # Frame gap between frames (at 30 fps = 1 second)
    max_rewind_steps: int = 4  # Maximum rewind steps for temporal augmentation

    # Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property)
    # During training with rewind: [obs_frames] + [rewind_frames]
    # During inference: [obs_frames] only

    # Architecture params
    image_dim: int = 512
    text_dim: int = 512
    hidden_dim: int = 768
    num_heads: int = 12
    num_layers: int = 8
    max_state_dim: int = 32
    drop_n_last_frames: int = 1
    batch_size: int = 64
    clip_batch_size: int = 64
    dropout: float = 0.1
    stage_loss_weight: float = 1.0  # Weight for stage classification loss when using subtask annotations

    rewind_probability: float = 0.8
    language_perturbation_probability: float = 0.2

    # Sparse annotations (high-level stages)
    num_sparse_stages: int = 1
    sparse_subtask_names: list | None = None
    sparse_temporal_proportions: list | None = None

    # Dense annotations (fine-grained stages)
    num_dense_stages: int | None = None
    dense_subtask_names: list | None = None
    dense_temporal_proportions: list | None = None

    pretrained_model_path: str | None = None
    device: str | None = None
    image_key: str = "observation.images.top"  # Key for the image used from the dataset
    state_key: str = "observation.state"

    # Populated by the processor (video_features, state_features, text_features)
    input_features: dict = field(default_factory=lambda: {})

    # Output features (updated in __post_init__)
    output_features: dict = field(
        default_factory=lambda: {
            "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
            "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
        }
    )

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "LANGUAGE": NormalizationMode.IDENTITY,
            "REWARD": NormalizationMode.IDENTITY,
        }
    )

    def __post_init__(self):
        super().__post_init__()

        if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
            raise ValueError(
                f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
            )

        if self.annotation_mode == "single_stage":
            # Use task description as stage name, full episode as one stage
            self.num_sparse_stages = 1
            self.sparse_subtask_names = ["task"]
            self.sparse_temporal_proportions = [1.0]
            self.num_dense_stages = None
            self.dense_subtask_names = None
            self.dense_temporal_proportions = None

        elif self.annotation_mode == "dense_only":
            self.num_sparse_stages = 1
            self.sparse_subtask_names = ["task"]
            self.sparse_temporal_proportions = [1.0]

        self.input_features = {}
        self.output_features = {}

        if self.image_key:
            self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL)

        self.input_features[self.state_key] = PolicyFeature(
            shape=(self.max_state_dim,),
            type=FeatureType.STATE,
        )

        # Update output features based on annotation_mode
        if self.annotation_mode in ["dense_only", "dual"]:
            self.output_features["sparse_stage"] = PolicyFeature(
                shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
            )
            self.output_features["sparse_progress"] = PolicyFeature(
                shape=(self.num_frames, 1), type=FeatureType.REWARD
            )
            dense_stages = self.num_dense_stages or self.num_sparse_stages
            self.output_features["dense_stage"] = PolicyFeature(
                shape=(self.num_frames, dense_stages), type=FeatureType.REWARD
            )
            self.output_features["dense_progress"] = PolicyFeature(
                shape=(self.num_frames, 1), type=FeatureType.REWARD
            )
        else:
            self.output_features["sparse_stage"] = PolicyFeature(
                shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
            )
            self.output_features["sparse_progress"] = PolicyFeature(
                shape=(self.num_frames, 1), type=FeatureType.REWARD
            )

        if self.max_rewind_steps >= self.n_obs_steps:
            raise ValueError(
                f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})"
            )
        if self.num_sparse_stages < 1:
            raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}")
        if (
            self.annotation_mode in ["dense_only", "dual"]
            and self.num_dense_stages is not None
            and self.num_dense_stages < 2
        ):
            raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}")

    def get_optimizer_preset(self) -> AdamWConfig:
        """Get the default optimizer configuration for SARM training."""
        return AdamWConfig(
            lr=5e-5,
            weight_decay=1e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Get the default learning rate scheduler configuration."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=5e-5,
            decay_lr=5e-6,
            num_warmup_steps=500,
            num_decay_steps=50000,
        )

    def validate_features(self) -> None:
        pass

    @property
    def uses_dual_heads(self) -> bool:
        """Whether the model uses dual heads (dense_only or dual annotation modes)."""
        return self.annotation_mode in ["dense_only", "dual"]

    @property
    def num_frames(self) -> int:
        """Total number of frames in a sequence.

        For training: 1 + n_obs_steps + max_rewind_steps
        The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)]
        """
        return 1 + self.n_obs_steps + self.max_rewind_steps

    @property
    def max_length(self) -> int:
        return self.num_frames

    @property
    def observation_delta_indices(self) -> list[int]:
        """Bidirectional frame sampling centered on the target frame.

        Example with n_obs_steps=8, gap=30:
          Before:  [-120, -90, -60, -30]  (4 frames)
          Current: [0]                    (1 frame)
          After:   [30, 60, 90, 120]      (4 frames)
          Total: 9 frames
        """
        half_steps = self.n_obs_steps // 2

        past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
        future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
        obs_deltas = past_deltas + [0] + future_deltas

        # Rewind placeholders
        rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]

        return obs_deltas + rewind_deltas
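
    # Sketch of the deltas produced by the defaults above (n_obs_steps=8,
    # frame_gap=30, max_rewind_steps=4); the trailing four entries are rewind
    # placeholders that the processor re-samples during training:
    #     >>> SARMConfig().observation_delta_indices
    #     [-120, -90, -60, -30, 0, 30, 60, 90, 120, -30, -60, -90, -120]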

    @property
    def action_delta_indices(self) -> None:
        """SARM is a reward model, not an action policy."""
        return None

    @property
    def reward_delta_indices(self) -> None:
        return None
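
# A hedged construction sketch (values follow the defaults above; nothing
# here is dataset-specific):
#     >>> cfg = SARMConfig(annotation_mode="single_stage")
#     >>> cfg.num_frames  # 1 + n_obs_steps + max_rewind_steps
#     13
#     >>> cfg.uses_dual_heads
#     False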
@@ -0,0 +1,793 @@
#!/usr/bin/env python

# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.

Paper: https://arxiv.org/abs/2509.25358

- StageTransformer: Predicts stage classification (sparse/dense)
- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage
"""

import json
import logging
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from torch import Tensor

from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import (
    normalize_stage_tau,
    pad_state_to_max_dim,
)


class StageTransformer(nn.Module):
    """
    Stage classification transformer for SARM.

    Predicts which stage/subtask the current frame belongs to.
    Supports both sparse (high-level) and dense (fine-grained) annotation schemes.

    Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D)
    Output: stage logits (B, T, num_classes)
    """

    def __init__(
        self,
        d_model: int = 512,
        vis_emb_dim: int = 512,
        text_emb_dim: int = 512,
        state_dim: int = 32,
        n_layers: int = 6,
        n_heads: int = 8,
        dropout: float = 0.1,
        num_cameras: int = 1,
        num_classes_sparse: int = 4,
        num_classes_dense: int = 8,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_cameras = num_cameras

        # Projections
        self.lang_proj = nn.Linear(text_emb_dim, d_model)
        self.visual_proj = nn.Linear(vis_emb_dim, d_model)
        self.state_proj = nn.Linear(state_dim, d_model)

        # Encoder
        enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(enc_layer, n_layers)

        # Positional bias on the first visual frame
        self.first_pos = nn.Parameter(torch.zeros(1, d_model))

        # Shared fusion MLP
        # Fuses (num_cameras + 2) streams: cameras + lang + state
        fused_in = d_model * (num_cameras + 2)
        self.fusion_backbone = nn.Sequential(
            nn.LayerNorm(fused_in),
            nn.Linear(fused_in, d_model),
            nn.ReLU(),
        )

        # Scheme-specific heads
        self.heads = nn.ModuleDict(
            {
                "sparse": nn.Linear(d_model, num_classes_sparse),
                "dense": nn.Linear(d_model, num_classes_dense),
            }
        )

    def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor:  # noqa: N803
        """
        Prepare language embeddings for fusion.

        Accepts lang_emb of shape:
        - (B, text_emb_dim) -> broadcast across time
        - (B, T, text_emb_dim) -> per-timestep (dense annotation mode)

        Returns: (B, 1, T, D)
        """
        if lang_emb.dim() == 3:
            # (B, T, E) -> (B, T, D) -> (B, 1, T, D)
            lang_proj = self.lang_proj(lang_emb).unsqueeze(1)
        else:
            # (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D)
            lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
        return lang_proj

    def forward(
        self,
        img_seq: torch.Tensor,  # (B, N, T, vis_emb_dim)
        lang_emb: torch.Tensor,  # (B, E) or (B, T, E)
        state: torch.Tensor,  # (B, T, state_dim)
        lengths: torch.Tensor,  # (B,) - valid sequence lengths
        scheme: str = "sparse",  # "sparse" or "dense"
    ) -> torch.Tensor:
        """
        Forward pass for stage classification.

        Args:
            img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras
            lang_emb: Language embeddings (B, E) or (B, T, E) for dense
            state: State features (B, T, state_dim)
            lengths: Valid sequence lengths (B,) for masking
            scheme: "sparse" or "dense" for head selection

        Returns:
            Stage logits (B, T, num_classes)
        """
        assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."

        B, N, T, _ = img_seq.shape  # noqa: N806
        D = self.d_model  # noqa: N806
        device = img_seq.device

        # Project inputs
        vis_proj = self.visual_proj(img_seq)  # (B, N, T, D)
        state_proj = self.state_proj(state).unsqueeze(1)  # (B, 1, T, D)
        lang_proj = self._prep_lang(lang_emb, B, T, D)  # (B, 1, T, D)

        # Concatenate streams
        # cameras + lang + state -> (B, N+2, T, D)
        x = torch.cat([vis_proj, lang_proj, state_proj], dim=1)

        # Add positional bias to the first visual frame
        x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos

        # Flatten to tokens for the Transformer
        x_tokens = x.view(B, (N + 2) * T, D)
        L = x_tokens.size(1)  # noqa: N806

        # Create padding mask
        base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1)  # (B, T)
        mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T)

        # Create causal mask
        causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)

        # Encode
        h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)

        # Reshape and fuse
        h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D)
        fused = self.fusion_backbone(h)  # (B, T, D)

        # Scheme-specific logits
        logits = self.heads[scheme](fused)  # (B, T, num_classes)
        return logits
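
# A minimal shape-check sketch (toy dimensions and random inputs, not from
# the source):
#     >>> m = StageTransformer(d_model=32, vis_emb_dim=16, text_emb_dim=16,
#     ...                      state_dim=8, n_layers=1, n_heads=2, num_classes_sparse=3)
#     >>> img = torch.randn(2, 1, 5, 16)   # (B, N=1 camera, T, vis_emb_dim)
#     >>> lang = torch.randn(2, 16)        # (B, E), broadcast across time
#     >>> state = torch.randn(2, 5, 8)     # (B, T, state_dim)
#     >>> lengths = torch.tensor([5, 3])
#     >>> m(img, lang, state, lengths, scheme="sparse").shape
#     torch.Size([2, 5, 3])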


class SubtaskTransformer(nn.Module):
    """
    Subtask progress regression transformer for SARM.

    Predicts within-stage normalized progress (tau) conditioned on a stage prior.
    The stage prior is a one-hot encoding passed from StageTransformer predictions.

    Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D)
    Output: tau predictions (B, T) in [0, 1]
    """

    def __init__(
        self,
        d_model: int = 512,
        vis_emb_dim: int = 512,
        text_emb_dim: int = 512,
        state_dim: int = 32,
        n_layers: int = 6,
        n_heads: int = 8,
        dropout: float = 0.1,
        num_cameras: int = 1,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_cameras = num_cameras

        # Projections
        self.lang_proj = nn.Linear(text_emb_dim, d_model)
        self.visual_proj = nn.Linear(vis_emb_dim, d_model)
        self.state_proj = nn.Linear(state_dim, d_model)

        # Encoder
        enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, n_layers)

        # Learned bias on the first visual frame
        self.first_pos = nn.Parameter(torch.zeros(1, d_model))

        # Shared fusion backbone
        # Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb
        fused_in = d_model * (num_cameras + 3)
        self.fusion_backbone = nn.Sequential(
            nn.LayerNorm(fused_in),
            nn.Linear(fused_in, d_model),
            nn.ReLU(),
        )

        # Scheme-specific regression heads
        self.heads = nn.ModuleDict(
            {
                "sparse": nn.Linear(d_model, 1),
                "dense": nn.Linear(d_model, 1),
            }
        )

    def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor:  # noqa: N803
        """Prepare language embeddings for fusion."""
        if lang_emb.dim() == 3:
            # (B, T, E) -> (B, T, D) -> (B, 1, T, D)
            return self.lang_proj(lang_emb).unsqueeze(1)
        else:
            # (B, E) -> (B, 1, 1, D) -> (B, 1, T, D)
            return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)

    def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor:
        """
        Deterministic projection of the one-hot stage to d_model by pad/truncate.

        Args:
            stage_prior: One-hot stage embedding (B, 1, T, C)

        Returns:
            Projected stage embedding (B, 1, T, d_model)
        """
        B, one, T, C = stage_prior.shape  # noqa: N806
        D = self.d_model  # noqa: N806
        if D == C:
            return stage_prior
        elif D > C:
            pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype)
            return torch.cat([stage_prior, pad], dim=-1)
        else:
            return stage_prior[..., :D]

    def forward(
        self,
        img_seq: torch.Tensor,  # (B, N, T, vis_emb_dim)
        lang_emb: torch.Tensor,  # (B, E) or (B, T, E)
        state: torch.Tensor,  # (B, T, state_dim)
        lengths: torch.Tensor,  # (B,) - valid sequence lengths
        stage_prior: torch.Tensor,  # (B, 1, T, C) one-hot from gen_stage_emb
        scheme: str = "sparse",  # "sparse" or "dense"
    ) -> torch.Tensor:
        """
        Forward pass for subtask progress regression.

        Args:
            img_seq: Image embeddings (B, N, T, vis_emb_dim)
            lang_emb: Language embeddings (B, E) or (B, T, E)
            state: State features (B, T, state_dim)
            lengths: Valid sequence lengths (B,) for masking
            stage_prior: One-hot stage prior (B, 1, T, num_classes)
            scheme: "sparse" or "dense" for head selection

        Returns:
            Tau predictions (B, T) in [0, 1] via sigmoid
        """
        assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."

        B, N, T, _ = img_seq.shape  # noqa: N806
        D = self.d_model  # noqa: N806
        device = img_seq.device

        # Project inputs
        vis_proj = self.visual_proj(img_seq)  # (B, N, T, D)
        state_proj = self.state_proj(state).unsqueeze(1)  # (B, 1, T, D)
        lang_proj = self._prep_lang(lang_emb, B, T, D)  # (B, 1, T, D)
        stage_emb = self._stage_to_dmodel(stage_prior)  # (B, 1, T, D)

        # Concatenate all streams
        # cameras + lang + state + stage_emb -> (B, N+3, T, D)
        x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1)

        # Add positional bias to the first visual frame
        x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos

        # Flatten to tokens
        x_tokens = x.view(B, (N + 3) * T, D)
        L = x_tokens.size(1)  # noqa: N806

        # Create padding mask
        base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1)
        mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T)

        # Create causal mask
        causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)

        # Encode
        h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)

        # Reshape and fuse
        h = h.view(B, N + 3, T, D)
        h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D)
        fused = self.fusion_backbone(h_flat)  # (B, T, D)

        # Scheme-specific regression head -> sigmoid
        r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1)  # (B, T)
        return r


def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
    """
    Generate one-hot stage embeddings from targets.

    Args:
        num_classes: Number of stage classes
        targets: Target values (B, T) where the integer part is the stage index

    Returns:
        One-hot stage embedding (B, 1, T, num_classes)
    """
    # Integer part of float targets -> [0, C-1]
    idx = targets.long().clamp(min=0, max=num_classes - 1)  # (B, T)
    C = num_classes  # noqa: N806
    # Identity-lookup one-hot
    stage_onehot = torch.eye(C, device=targets.device)[idx]  # (B, T, C)
    stage_onehot = stage_onehot.unsqueeze(1)  # (B, 1, T, C)
    return stage_onehot
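
# Sketch of the stage.tau convention (toy values): the integer part selects
# the one-hot stage; the fractional part is ignored here.
#     >>> gen_stage_emb(3, torch.tensor([[0.2, 1.7, 2.4]]))
#     tensor([[[[1., 0., 0.],
#               [0., 1., 0.],
#               [0., 0., 1.]]]])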


class SARMRewardModel(PreTrainedPolicy):
    """
    SARM Reward Model for stage-aware task completion rewards.

    Uses two separate transformer models:
    - StageTransformer: Classifies which stage/subtask
    - SubtaskTransformer: Predicts within-stage progress (tau)

    Training uses 75%/25% GT/predicted stage conditioning (teacher forcing).
    """

    name = "sarm"
    config_class = SARMConfig

    def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
        super().__init__(config, dataset_stats)
        config.validate_features()
        self.config = config
        self.dataset_stats = dataset_stats
        self.device = torch.device(
            config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Load temporal proportions based on annotation_mode
        if config.annotation_mode == "single_stage":
            logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}")
        elif dataset_meta is not None:
            self._load_temporal_proportions(dataset_meta)

        # Create two separate models
        self.stage_model = StageTransformer(
            d_model=config.hidden_dim,
            vis_emb_dim=config.image_dim,
            text_emb_dim=config.text_dim,
            state_dim=config.max_state_dim,
            n_layers=config.num_layers,
            n_heads=config.num_heads,
            dropout=config.dropout,
            num_cameras=1,  # Single camera for now
            num_classes_sparse=config.num_sparse_stages,
            num_classes_dense=config.num_dense_stages or config.num_sparse_stages,
        )

        self.subtask_model = SubtaskTransformer(
            d_model=config.hidden_dim,
            vis_emb_dim=config.image_dim,
            text_emb_dim=config.text_dim,
            state_dim=config.max_state_dim,
            n_layers=config.num_layers,
            n_heads=config.num_heads,
            dropout=config.dropout,
            num_cameras=1,
        )

        self.stage_model.to(self.device)
        self.subtask_model.to(self.device)

        # GT/predicted stage ratio for teacher forcing
        self.gt_stage_ratio = 0.75

        if config.uses_dual_heads:
            logging.info(
                f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, "
                f"{config.num_dense_stages} dense stages"
            )
        else:
            logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages")

        logging.info(f"SARM initialized on {self.device}")

    def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]:
        """Load temporal proportions from a JSON file (preserving order)."""
        if not path.exists():
            raise ValueError(
                f"{annotation_type.capitalize()} temporal proportions not found at {path}. "
                f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations."
            )
        with open(path) as f:
            proportions_dict = json.load(f)
        names = list(proportions_dict.keys())
        logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}")
        logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}")
        return names, [proportions_dict[name] for name in names]
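
    # The expected JSON layout is a name -> proportion mapping (the stage
    # names below are invented for illustration; proportions are assumed to
    # sum to 1.0 over the episode):
    #     {"reach": 0.25, "grasp": 0.25, "lift": 0.5}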

    def _load_temporal_proportions(self, dataset_meta) -> None:
        """Load temporal proportions based on annotation_mode."""
        meta_path = dataset_meta.root / "meta"

        if self.config.annotation_mode == "dual":
            names, props = self._load_proportions_from_json(
                meta_path / "temporal_proportions_sparse.json", "sparse"
            )
            (
                self.config.num_sparse_stages,
                self.config.sparse_subtask_names,
                self.config.sparse_temporal_proportions,
            ) = len(names), names, props

        if self.config.annotation_mode in ["dense_only", "dual"]:
            names, props = self._load_proportions_from_json(
                meta_path / "temporal_proportions_dense.json", "dense"
            )
            (
                self.config.num_dense_stages,
                self.config.dense_subtask_names,
                self.config.dense_temporal_proportions,
            ) = len(names), names, props
            if self.config.annotation_mode == "dense_only":
                logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}")

    def to(self, device):
        """Override the `to` method to ensure all components move together."""
        super().to(device)
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.stage_model.to(device)
        self.subtask_model.to(device)
        return self

    @torch.no_grad()
    def calculate_rewards(
        self,
        text_embeddings: np.ndarray | torch.Tensor,
        video_embeddings: np.ndarray | torch.Tensor,
        state_features: np.ndarray | torch.Tensor | None = None,
        lengths: np.ndarray | torch.Tensor | None = None,
        return_all_frames: bool = False,
        return_stages: bool = False,
        return_confidence: bool = False,
        head_mode: str | None = "sparse",
        frame_index: int | None = None,
    ) -> np.ndarray | tuple:
        """
        Calculate rewards for given text, video, and state representations.

        This is the canonical method for SARM reward computation, used for:
        - Inference/visualization
        - RA-BC weight computation

        Args:
            text_embeddings: Encoded text representations (batch_size, 512)
            video_embeddings: Encoded video representations (batch_size, num_frames, 512)
            state_features: Joint state features (batch_size, num_frames, state_dim)
            lengths: Valid sequence lengths (batch_size,)
            return_all_frames: If True, return rewards for all frames
            return_stages: If True, also return stage predictions
            return_confidence: If True, also return stage confidence
            head_mode: Which head to use ("sparse" or "dense")
            frame_index: Index of the target frame to extract (default: n_obs_steps).

        Returns:
            Rewards and optionally stage probs/confidence.
        """
        if isinstance(text_embeddings, np.ndarray):
            text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
        if isinstance(video_embeddings, np.ndarray):
            video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
        if state_features is not None and isinstance(state_features, np.ndarray):
            state_features = torch.tensor(state_features, dtype=torch.float32)

        # Handle the single-sample case
        if text_embeddings.dim() == 1:
            text_embeddings = text_embeddings.unsqueeze(0)
            video_embeddings = video_embeddings.unsqueeze(0)
            if state_features is not None:
                state_features = state_features.unsqueeze(0)
            single_sample = True
        else:
            single_sample = False

        batch_size = video_embeddings.shape[0]
        seq_len = video_embeddings.shape[1]

        scheme = head_mode

        # Default lengths if not provided
        if lengths is None:
            lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)
        elif isinstance(lengths, np.ndarray):
            lengths = torch.tensor(lengths, dtype=torch.int32)

        # Reshape video to (B, N, T, D) for the multi-camera format
        # Currently single camera: (B, T, D) -> (B, 1, T, D)
        img_seq = video_embeddings.unsqueeze(1).to(self.device)
        lang_emb = text_embeddings.to(self.device)
        state = (
            state_features.to(self.device)
            if state_features is not None
            else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
        )
        lens = lengths.to(self.device)

        # Pad state to max_state_dim
        state = pad_state_to_max_dim(state, self.config.max_state_dim)

        # Get num_classes for this scheme
        num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages

        # Run stage model
        stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme)
        stage_probs = F.softmax(stage_logits, dim=-1)  # (B, T, num_classes)
        stage_idx = stage_probs.argmax(dim=-1)  # (B, T)
        stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1)  # (B, T)

        # Create one-hot stage prior
        stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float()  # (B, T, C)
        stage_emb = stage_onehot.unsqueeze(1)  # (B, 1, T, C)

        # Run subtask model
        tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme)

        # Compute final reward: stage + tau
        raw_reward = stage_idx.float() + tau_pred  # (B, T)

        # Normalize to [0, 1] using temporal proportions for proper weighting
        if scheme == "sparse":
            normalized_reward = normalize_stage_tau(
                raw_reward,
                num_stages=num_classes,
                temporal_proportions=self.config.sparse_temporal_proportions,
                subtask_names=self.config.sparse_subtask_names,
            )
        else:
            normalized_reward = normalize_stage_tau(
                raw_reward,
                num_stages=num_classes,
                temporal_proportions=self.config.dense_temporal_proportions,
                subtask_names=self.config.dense_subtask_names,
            )

        # Default frame index is n_obs_steps (last observation frame)
        if frame_index is None:
            frame_index = self.config.n_obs_steps

        # Prepare outputs
        if return_all_frames:
            rewards = normalized_reward.cpu().numpy()
        else:
            rewards = normalized_reward[:, frame_index].cpu().numpy()

        if single_sample:
            rewards = rewards[0]

        outputs = [rewards]
        if return_stages:
            probs = stage_probs.cpu().numpy()
            if single_sample:
                probs = probs[0]
            outputs.append(probs)
        if return_confidence:
            conf = stage_conf.cpu().numpy()
            if single_sample:
                conf = conf[0]
            outputs.append(conf)

        return outputs[0] if len(outputs) == 1 else tuple(outputs)
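
    # A hedged usage sketch (shapes follow the docstring; the embeddings are
    # random placeholders, so the returned value is arbitrary):
    #     >>> model = SARMRewardModel(SARMConfig())
    #     >>> text = np.random.randn(512).astype(np.float32)
    #     >>> video = np.random.randn(13, 512).astype(np.float32)  # num_frames x 512
    #     >>> reward = model.calculate_rewards(text, video)
    #     >>> reward.shape  # scalar progress at the target frame
    #     ()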

    def train(self, mode: bool = True):
        """Set training mode for both models."""
        super().train(mode)
        self.stage_model.train(mode)
        self.subtask_model.train(mode)
        return self

    def eval(self):
        """Set evaluation mode for both models."""
        return self.train(False)

    def parameters(self):
        """Override to return trainable parameters from both models."""
        from itertools import chain

        return chain(self.stage_model.parameters(), self.subtask_model.parameters())

    def get_optim_params(self):
        """Override to return optimizer parameters from both models."""
        return self.parameters()

    def reset(self):
        """Required by PreTrainedPolicy but not used for reward models."""
        pass

    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Required by PreTrainedPolicy but not used for reward models."""
        raise NotImplementedError("SARM model does not predict action chunks")

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Required by PreTrainedPolicy but not used for SARM."""
        raise NotImplementedError("SARM model does not select actions")

    def _train_step(
        self,
        img_emb: torch.Tensor,  # (B, N, T, D)
        lang_emb: torch.Tensor,  # (B, E) or (B, T, E)
        state: torch.Tensor,  # (B, T, state_dim)
        lengths: torch.Tensor,  # (B,)
        targets: torch.Tensor,  # (B, T) - format: stage.tau
        scheme: str,
    ) -> dict[str, torch.Tensor]:
        """
        Single training step for one annotation scheme.

        Implements 75%/25% GT/predicted stage conditioning.

        Args:
            img_emb: Image embeddings (B, N, T, D)
            lang_emb: Language embeddings
            state: State features
            lengths: Valid sequence lengths
            targets: Target values where floor=stage, remainder=tau
            scheme: "sparse" or "dense"

        Returns:
            Dict with stage_loss, subtask_loss, total_loss
        """
        num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages

        # Ground truth: stage (integer) and tau (fractional)
        # Clamp stage indices to the valid range [0, num_classes-1] to handle edge cases
        # where targets may exceed the expected range (e.g., frames between subtasks)
        gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1)  # (B, T)
        gt_tau = torch.remainder(targets, 1.0)  # (B, T)

        # Run stage model
        stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme)

        # 75%/25% GT/predicted stage conditioning
        if random.random() < self.gt_stage_ratio:
            # Mode 1: Use ground truth stage -> one-hot
            stage_emb = gen_stage_emb(num_classes, targets)  # (B, 1, T, C)
        else:
            # Mode 2: Use predicted stage argmax -> one-hot
            stage_idx = stage_pred.argmax(dim=-1)  # (B, T)
            stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float()  # (B, T, C)
            stage_emb = stage_onehot.unsqueeze(1)  # (B, 1, T, C)

        # Run subtask model with stage prior
        tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme)

        # Compute losses
        stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean")
        subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean")

        return {
            "stage_loss": stage_loss,
            "subtask_loss": subtask_loss,
            "total_loss": stage_loss + subtask_loss,
        }

    def forward(self, batch):
        """
        Forward pass for SARM reward model training.

        Uses the stage+tau target format where:
        - Integer part = stage index
        - Fractional part = within-stage progress (tau)

        Training uses 75%/25% GT/predicted stage conditioning.

        Args:
            batch: Dictionary with 'observation' containing:
                - 'video_features': (B, T, 512) pre-encoded video features
                - 'text_features': (B, 512) or (B, T, 512) text features
                - 'state_features': (B, T, state_dim) joint state features
                - 'lengths': (B,) valid sequence lengths
                - 'sparse_targets': (B, T) sparse targets (stage.tau format)
                - 'dense_targets': (B, T) dense targets (optional, for dual mode)

        Returns:
            Tuple of (total_loss, output_dict with loss components)
        """
        observation = batch.get("observation", batch)

        # Extract features
        video_features = observation["video_features"].to(self.device)
        text_features = observation["text_features"].to(self.device)
        state_features = observation.get("state_features")
        if state_features is not None:
            state_features = state_features.to(self.device)

        batch_size = video_features.shape[0]
        seq_len = video_features.shape[1]

        # Get lengths (default to full sequence)
        lengths = observation.get("lengths")
        if lengths is None:
            lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device)
        else:
            lengths = lengths.to(self.device)

        # Reshape video to (B, N, T, D) - single camera
        img_emb = video_features.unsqueeze(1)

        # Pad state to max_state_dim
        if state_features is None:
            state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
        else:
            state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim)

        output_dict = {}
        total_loss = torch.tensor(0.0, device=self.device)

        # Sparse training (always)
        sparse_targets = observation.get("sparse_targets")
        if sparse_targets is None:
            # Try the legacy format
            sparse_targets = observation.get("targets")
            if sparse_targets is None:
                raise ValueError("sparse_targets (or targets) is required for SARM training")
        sparse_targets = sparse_targets.to(self.device)

        sparse_result = self._train_step(
            img_emb, text_features, state_features, lengths, sparse_targets, scheme="sparse"
        )
        output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item()
        output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item()
        total_loss = total_loss + sparse_result["total_loss"]

        # Dense training (if dual mode)
        if self.config.uses_dual_heads:
            dense_targets = observation.get("dense_targets")
            if dense_targets is not None:
                dense_targets = dense_targets.to(self.device)
                dense_result = self._train_step(
                    img_emb, text_features, state_features, lengths, dense_targets, scheme="dense"
                )
                output_dict["dense_stage_loss"] = dense_result["stage_loss"].item()
                output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item()
                total_loss = total_loss + dense_result["total_loss"]

        output_dict["total_loss"] = total_loss.item()
        return total_loss, output_dict


def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
    """Compute cross-entropy loss for stage classification."""
    _, _, num_stages = stage_logits.shape
    stage_logits_flat = stage_logits.reshape(-1, num_stages)
    # Clamp target stage indices to the valid range [0, num_stages-1]
    target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1)
    return F.cross_entropy(stage_logits_flat, target_stages_flat)
|
||||
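
# A minimal worked example (values hypothetical) of the stage.tau target format
# consumed by the training step above: the integer part indexes the stage, the
# fractional part is the within-stage progress tau.
#
# >>> import torch
# >>> targets = torch.tensor([[0.25, 1.8, 2.999]])  # (B=1, T=3) stage.tau targets
# >>> torch.floor(targets).long()                   # stage indices
# tensor([[0, 1, 2]])
# >>> torch.remainder(targets, 1.0)                 # within-stage tau
# tensor([[0.2500, 0.8000, 0.9990]])
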
@@ -0,0 +1,518 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SARM Processor for encoding images/text and generating stage+tau targets."""

import random
from typing import Any

import numpy as np
import pandas as pd
import torch
from faker import Faker
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import (
    apply_rewind_augmentation,
    compute_absolute_indices,
    find_stage_and_tau,
    pad_state_to_max_dim,
)
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    ProcessorStep,
    RenameObservationsProcessorStep,
)
from lerobot.processor.converters import (
    from_tensor_to_numpy,
    policy_action_to_transition,
    transition_to_policy_action,
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


class SARMEncodingProcessorStep(ProcessorStep):
    """ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM."""

    def __init__(
        self,
        config: SARMConfig,
        image_key: str | None = None,
        dataset_meta=None,
        dataset_stats: dict | None = None,
    ):
        super().__init__()
        self.config = config
        self.image_key = image_key or config.image_key
        self.dataset_meta = dataset_meta
        self.dataset_stats = dataset_stats
        self.annotation_mode = config.annotation_mode

        # Helper to create temporal proportions dict
        def make_props_dict(names, props):
            return dict(zip(names, props, strict=True)) if names and props else None

        # Sparse annotations (always needed)
        self.sparse_temporal_proportions = make_props_dict(
            config.sparse_subtask_names, config.sparse_temporal_proportions
        )
        self.sparse_subtask_names = config.sparse_subtask_names

        # Dense annotations (only for dual mode)
        self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None
        self.dense_temporal_proportions = (
            make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions)
            if config.uses_dual_heads
            else None
        )

        self.device = torch.device(
            self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
        self.clip_model.to(self.device)
        self.clip_model.eval()

        self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"]
        self.fake = Faker()

    def _find_episode_for_frame(self, frame_idx: int) -> int:
        """Find the episode index for a given frame index."""
        for ep_idx in range(len(self.dataset_meta.episodes)):
            ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
            ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
            if ep_start <= frame_idx < ep_end:
                return ep_idx
        return 0

    def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
        """Get episode indices for each frame index."""
        if episode_index is None:
            return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])

        episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))

        # If single episode but multiple frames, compute the episode for each frame
        if len(episode_indices) == 1 and len(frame_indices) > 1:
            return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])

        return episode_indices

    def _generate_perturbed_task(self) -> str:
        """Generate a random perturbed task string for language perturbation."""
        num_words = random.randint(1, 5)
        verb = random.choice(self.verbs)
        phrase = " ".join([verb] + self.fake.words(nb=num_words))
        return phrase

    def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]:
        """Get global subtask names and temporal proportions for an annotation type."""
        if annotation_type == "dense":
            return self.dense_subtask_names, self.dense_temporal_proportions
        return self.sparse_subtask_names, self.sparse_temporal_proportions

    def _load_episode_annotations(
        self,
        ep_idx: int,
        episodes_df: pd.DataFrame | None,
        annotation_type: str,
        global_names: list[str],
    ) -> tuple[list | None, list | None, list | None]:
        """Load subtask annotations for an episode from DataFrame."""
        # Single-stage mode (linear progress 0 -> 1): no per-episode annotations needed
        if episodes_df is None or len(global_names) == 1:
            return None, None, None

        # Resolve column name with fallback
        def col(suffix):
            prefixed = f"{annotation_type}_{suffix}"
            return prefixed if prefixed in episodes_df.columns else suffix

        col_names = col("subtask_names")
        if col_names not in episodes_df.columns or ep_idx >= len(episodes_df):
            return None, None, None

        subtask_names = episodes_df.loc[ep_idx, col_names]
        if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
            return None, None, None

        return (
            subtask_names,
            episodes_df.loc[ep_idx, col("subtask_start_frames")],
            episodes_df.loc[ep_idx, col("subtask_end_frames")],
        )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """
        Encode images, text, and normalize states in the transition.

        Implements SARM training data preparation:
        - Applies language perturbation (20% probability)
        - Applies rewind augmentation (80% probability)
        - Generates stage+tau targets for all frames
        - Outputs lengths tensor for valid sequence masking
        """
        new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition)
        observation = new_transition.get(TransitionKey.OBSERVATION)
        comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})

        frame_index = comp_data.get("index")
        episode_index = comp_data.get("episode_index")

        if frame_index is None:
            raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
        if episode_index is None:
            raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")

        frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
        episode_indices = self._get_episode_indices(frame_indices, episode_index)

        image = observation.get(self.image_key)
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()

        # If 4D (T, C, H, W) from delta_timestamps, add batch dim
        # If 3D (C, H, W) single frame, add batch and time dims
        if image.ndim == 4:
            image = image[np.newaxis, ...]  # (T, C, H, W) -> (1, T, C, H, W)
        elif image.ndim == 3:
            image = image[np.newaxis, np.newaxis, ...]  # (C, H, W) -> (1, 1, C, H, W)

        batch_size = image.shape[0]
        total_frames = image.shape[1]  # Should be 13: 9 obs + 4 rewind placeholders
        n_obs_steps = self.config.n_obs_steps
        max_rewind_steps = self.config.max_rewind_steps
        n_obs_frames = 1 + n_obs_steps  # 9 observation frames (including current)

        # Rewind augmentation
        rewind_steps = torch.zeros(batch_size, dtype=torch.int32)
        apply_rewind = self.training and random.random() < self.config.rewind_probability

        if apply_rewind and self.dataset_meta is not None:
            for b_idx, (ep_idx, frame_idx) in enumerate(
                zip(episode_indices.tolist(), frame_indices.tolist(), strict=True)
            ):
                ep_idx, frame_idx = int(ep_idx), int(frame_idx)
                ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]

                rewind_step, _ = apply_rewind_augmentation(
                    frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap
                )
                rewind_steps[b_idx] = rewind_step

        # Compute valid lengths: n_obs_frames + rewind_steps
        lengths = n_obs_frames + rewind_steps  # (B,)

        # Apply rewind masking to images
        # For frames beyond the valid length, we mask with zeros (or copy the last valid frame)
        for b_idx in range(batch_size):
            valid_len = lengths[b_idx].item()
            if valid_len < total_frames:
                image[b_idx, valid_len:] = 0  # Zero out frames beyond valid length

        # Encode images with CLIP
        video_features = self._encode_images_batch(image)
        observation["video_features"] = video_features

        state_key = self.config.state_key
        state_data = observation.get(state_key)

        if isinstance(state_data, torch.Tensor):
            state_tensor = state_data.float()
        else:
            state_tensor = torch.tensor(state_data, dtype=torch.float32)

        if state_tensor.ndim == 2:
            state_tensor = state_tensor.unsqueeze(0)  # (T, D) -> (1, T, D)
        elif state_tensor.ndim == 1:
            state_tensor = state_tensor.unsqueeze(0).unsqueeze(0)  # (D,) -> (1, 1, D)

        # Apply the same rewind masking to state
        for b_idx in range(batch_size):
            valid_len = lengths[b_idx].item()
            if valid_len < state_tensor.shape[1]:
                state_tensor[b_idx, valid_len:] = 0  # Zero out frames beyond valid length

        observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)

        task = comp_data.get("task")
        if isinstance(task, list):
            task = task[0] if task else ""

        # Apply language perturbation during training (20% probability)
        # When perturbed, targets will be zeroed to train the model to output low values for irrelevant text
        apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability
        if apply_perturbation:
            task = self._generate_perturbed_task()

        # Encode text with CLIP
        observation["text_features"] = self._encode_text_clip(task, batch_size)

        # Store lengths for the model
        observation["lengths"] = lengths

        # Generate targets. When language is perturbed, targets are zeroed so those
        # samples don't contribute to the progress loss.
        if self.dataset_meta is not None:
            episodes_df = None
            if self.sparse_subtask_names != ["task"]:
                episodes_df = self.dataset_meta.episodes.to_pandas()

            # Generate sparse targets
            if self.sparse_temporal_proportions is not None:
                if apply_perturbation:
                    # Zero targets when language is perturbed
                    sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
                else:
                    sparse_targets = self._compute_batch_targets(
                        frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse"
                    )
                observation["sparse_targets"] = sparse_targets

            # Generate dense targets (for dual mode)
            if self.config.uses_dual_heads and self.dense_temporal_proportions is not None:
                if apply_perturbation:
                    # Zero targets when language is perturbed
                    dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
                else:
                    dense_targets = self._compute_batch_targets(
                        frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense"
                    )
                observation["dense_targets"] = dense_targets

        new_transition[TransitionKey.OBSERVATION] = observation
        return new_transition

    def _compute_batch_targets(
        self,
        frame_indices: np.ndarray,
        episode_indices: np.ndarray,
        lengths: torch.Tensor,
        rewind_steps: torch.Tensor,
        episodes_df: pd.DataFrame | None,
        annotation_type: str,
    ) -> torch.Tensor:
        """Compute stage+tau targets for a batch of samples."""
        batch_size = len(frame_indices)
        n_obs_steps = self.config.n_obs_steps
        max_rewind_steps = self.config.max_rewind_steps
        total_frames = 1 + n_obs_steps + max_rewind_steps
        frame_gap = self.config.frame_gap

        global_names, temporal_props = self._get_annotation_config(annotation_type)
        targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)

        for b_idx in range(batch_size):
            ep_idx = int(episode_indices[b_idx])
            frame_idx = int(frame_indices[b_idx])

            ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
            ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
            ep_length = ep_end - ep_start

            subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations(
                ep_idx, episodes_df, annotation_type, global_names
            )

            # Compute observation frame indices
            obs_indices, _ = compute_absolute_indices(
                frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap
            )
            obs_indices = obs_indices.tolist()

            # Compute targets for observation frames
            for t_idx, abs_idx in enumerate(obs_indices):
                rel_frame = abs_idx - ep_start
                targets[b_idx, t_idx] = find_stage_and_tau(
                    rel_frame,
                    ep_length,
                    subtask_names,
                    subtask_start_frames,
                    subtask_end_frames,
                    global_names,
                    temporal_props,
                    return_combined=True,
                )

            # Compute targets for rewind frames (if any)
            rewind_step = rewind_steps[b_idx].item()
            if rewind_step > 0:
                _, rewind_indices = apply_rewind_augmentation(
                    frame_idx,
                    ep_start,
                    n_obs_steps,
                    max_rewind_steps,
                    frame_gap=frame_gap,
                    rewind_step=rewind_step,
                )

                for r_idx, abs_idx in enumerate(rewind_indices[:rewind_step]):
                    rel_frame = max(0, abs_idx - ep_start)
                    targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau(
                        rel_frame,
                        ep_length,
                        subtask_names,
                        subtask_start_frames,
                        subtask_end_frames,
                        global_names,
                        temporal_props,
                        return_combined=True,
                    )

        return targets

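    # A minimal sketch (hypothetical values) of the layout produced above with the
    # default config (n_obs_steps=8, max_rewind_steps=4): each row of `targets`
    # holds 1 + 8 = 9 observation slots followed by 4 rewind slots (13 total);
    # slots beyond lengths[b] = 9 + rewind_steps[b] keep their 0.0 initialization.
    #
    # >>> targets.shape
    # (batch_size, 13)
    # >>> targets[0, :9]   # stage.tau values for the bidirectional obs window
    # >>> targets[0, 9:]   # rewind slots; unused ones remain 0.0
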
    @property
    def training(self) -> bool:
        return getattr(self, "_training_mode", True)

    def train(self, mode: bool = True):
        """Set training mode for augmentation decisions."""
        self._training_mode = mode
        return self

    def eval(self):
        """Set evaluation mode (disable augmentations)."""
        return self.train(False)

    @torch.no_grad()
    def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
        """Encode a batch of images using CLIP.

        Args:
            images: Batched images with shape (B, T, C, H, W)

        Returns:
            Encoded feature vectors with shape (B, T, 512)
        """
        batch_size, seq_length = images.shape[0], images.shape[1]
        images = images.reshape(batch_size * seq_length, *images.shape[2:])

        num_frames = images.shape[0]
        images_list = []
        for i in range(num_frames):
            img = images[i]
            if img.shape[0] in [1, 3]:  # Channel first (C, H, W)
                img = img.transpose(1, 2, 0)

            # Handle single channel
            if img.shape[-1] == 1:
                img = np.repeat(img, 3, axis=-1)

            if img.dtype != np.uint8:
                img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)

            images_list.append(Image.fromarray(img))

        all_embeddings = []
        for i in range(0, num_frames, self.config.clip_batch_size):
            batch_imgs = images_list[i : i + self.config.clip_batch_size]

            inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Get image embeddings
            embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()

            # Handle single frame case
            if embeddings.dim() == 1:
                embeddings = embeddings.unsqueeze(0)

            all_embeddings.append(embeddings)

        all_embeddings = torch.cat(all_embeddings)  # (B*T, 512)
        all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1)  # (B, T, 512)

        return all_embeddings

    @torch.no_grad()
    def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
        """Encode text using the CLIP text encoder (per SARM paper A.4).

        Args:
            text: Task description text to encode
            batch_size: Batch size to replicate for

        Returns:
            Encoded text features with shape (B, 512)
        """
        inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
        text_embedding = text_embedding.expand(batch_size, -1)

        return text_embedding

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Add encoded features to the observation features."""
        features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature(
            type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim)
        )
        features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature(
            type=FeatureType.LANGUAGE, shape=(self.config.text_dim,)
        )
        features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature(
            type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim)
        )
        return features


def make_sarm_pre_post_processors(
    config: SARMConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    dataset_meta=None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Create pre-processor and post-processor pipelines for SARM."""
    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=[
                AddBatchDimensionProcessorStep(),
                RenameObservationsProcessorStep(rename_map={}),
                NormalizerProcessorStep(
                    features={**config.input_features, **config.output_features},
                    norm_map=config.normalization_mapping,
                    stats=dataset_stats,
                ),
                SARMEncodingProcessorStep(
                    config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats
                ),
                DeviceProcessorStep(device=config.device),
            ],
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=[DeviceProcessorStep(device="cpu")],
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
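
# A minimal usage sketch of the factory above; `cfg` and `dataset` stand in for a
# SARMConfig and a LeRobotDataset built elsewhere.
#
# >>> preprocessor, postprocessor = make_sarm_pre_post_processors(
# ...     config=cfg, dataset_stats=dataset.meta.stats, dataset_meta=dataset.meta
# ... )
# >>> processed = preprocessor(raw_batch)  # CLIP features + stage.tau targets
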
@@ -0,0 +1,295 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812


def find_stage_and_tau(
    current_frame: int,
    episode_length: int,
    subtask_names: list | None,
    subtask_start_frames: list | None,
    subtask_end_frames: list | None,
    global_subtask_names: list,
    temporal_proportions: dict,
    return_combined: bool = False,
) -> tuple[int, float] | float:
    """Find stage and within-stage progress (tau) for a frame.

    Args:
        current_frame: Frame index relative to episode start
        episode_length: Total frames in episode
        subtask_names: Subtask names for this episode (None for single_stage)
        subtask_start_frames: Subtask start frames
        subtask_end_frames: Subtask end frames
        global_subtask_names: Global list of all subtask names
        temporal_proportions: Dict of temporal proportions
        return_combined: If True, return stage+tau as float; else (stage_idx, tau) tuple

    Returns:
        Float (stage.tau) if return_combined, else (stage_idx, tau) tuple
    """
    stage_idx, tau = 0, 0.0
    num_stages = len(global_subtask_names)

    # Single-stage mode: linear progress from 0 to 1
    if num_stages == 1:
        tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1)))
    elif subtask_names is None:
        pass  # stage_idx=0, tau=0.0
    elif current_frame < subtask_start_frames[0]:
        pass  # Before first subtask: stage_idx=0, tau=0.0
    elif current_frame > subtask_end_frames[-1]:
        stage_idx, tau = num_stages - 1, 0.999  # After last subtask
    else:
        # Find which subtask this frame belongs to
        found = False
        for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True):
            if start <= current_frame <= end:
                stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0
                tau = compute_tau(current_frame, start, end)
                found = True
                break
        # Frame between subtasks - use the previous subtask's end state
        if not found:
            for j in range(len(subtask_names) - 1):
                if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]:
                    name = subtask_names[j]
                    stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j
                    tau = 1.0
                    break

    if return_combined:
        # Clamp to avoid overflow at the end
        if stage_idx >= num_stages - 1 and tau >= 1.0:
            return num_stages - 1 + 0.999
        return stage_idx + tau
    return stage_idx, tau

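# A worked example with hypothetical annotations: a 100-frame episode with two
# subtasks, "grasp" spanning frames 0-39 and "place" spanning frames 40-99.
#
# >>> names = ["grasp", "place"]
# >>> props = {"grasp": 0.4, "place": 0.6}
# >>> find_stage_and_tau(20, 100, names, [0, 40], [39, 99], names, props)
# (0, 0.5128...)  # stage 0, tau = (20 - 0) / (39 - 0)
# >>> find_stage_and_tau(70, 100, names, [0, 40], [39, 99], names, props, return_combined=True)
# 1.5084...       # stage 1 + tau = 1 + (70 - 40) / (99 - 40)

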
def compute_absolute_indices(
    frame_idx: int,
    ep_start: int,
    ep_end: int,
    n_obs_steps: int,
    frame_gap: int = 30,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute absolute frame indices with clamping for a bidirectional observation sequence.

    Bidirectional sampling centered on the target frame:
    - Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames)
    - Current: [0] (1 frame)
    - After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames)
    - Total: n_obs_steps + 1 frames

    Out-of-bounds frames are clamped (duplicated from the boundary).

    Args:
        frame_idx: Target frame index (center frame of the sequence)
        ep_start: Episode start index
        ep_end: Episode end index (exclusive)
        n_obs_steps: Number of observation steps (must be even for symmetric sampling)
        frame_gap: Gap between observation frames

    Returns:
        Tuple of (indices, out_of_bounds_flags)
    """
    half_steps = n_obs_steps // 2

    # Bidirectional deltas: past + current + future
    past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)]
    future_deltas = [frame_gap * i for i in range(1, half_steps + 1)]
    delta_indices = past_deltas + [0] + future_deltas

    frames = []
    out_of_bounds = []

    for delta in delta_indices:
        target_idx = frame_idx + delta
        # Clamp to episode bounds (duplicate boundary frames for out-of-bounds)
        clamped_idx = max(ep_start, min(ep_end - 1, target_idx))
        frames.append(clamped_idx)
        # Flag as out-of-bounds if clamping occurred
        out_of_bounds.append(1 if target_idx != clamped_idx else 0)

    return torch.tensor(frames), torch.tensor(out_of_bounds)

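# A worked example with n_obs_steps=8 and the default frame_gap=30 (half_steps=4):
# deltas are [-120, -90, -60, -30, 0, 30, 60, 90, 120]; episode bounds are hypothetical.
#
# >>> idx, oob = compute_absolute_indices(frame_idx=1000, ep_start=950, ep_end=1100, n_obs_steps=8)
# >>> idx.tolist()
# [950, 950, 950, 970, 1000, 1030, 1060, 1090, 1099]
# >>> oob.tolist()
# [1, 1, 1, 0, 0, 0, 0, 0, 1]

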
def apply_rewind_augmentation(
    frame_idx: int,
    ep_start: int,
    n_obs_steps: int,
    max_rewind_steps: int,
    frame_gap: int = 30,
    rewind_step: int | None = None,
) -> tuple[int, list[int]]:
    """
    Generate rewind frame indices for temporal augmentation.

    Rewind simulates going backwards through previously seen frames,
    starting from before the earliest observation frame (for bidirectional sampling).
    Appends reversed frames after the observation sequence.

    Args:
        frame_idx: Target frame index (center of the bidirectional observation window)
        ep_start: Episode start index
        n_obs_steps: Number of observation steps
        max_rewind_steps: Maximum rewind steps
        frame_gap: Gap between frames
        rewind_step: If provided, use this exact rewind step (for deterministic behavior).
            If None, sample randomly.

    Returns:
        Tuple of (rewind_step, rewind_indices)
    """
    # For bidirectional sampling, the earliest obs frame is at frame_idx - half_steps * frame_gap
    half_steps = n_obs_steps // 2
    earliest_obs_frame = frame_idx - half_steps * frame_gap

    # Required history: frames before the earliest observation frame
    if earliest_obs_frame <= ep_start:
        return 0, []  # No history before the observation window

    # Max valid rewind steps based on available history before the earliest obs frame
    available_history = earliest_obs_frame - ep_start
    max_valid_step = available_history // frame_gap
    max_rewind = min(max_rewind_steps, max(0, max_valid_step))

    if max_rewind <= 0:
        return 0, []

    # Sample rewind steps if not provided
    rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind)

    if rewind_step == 0:
        return 0, []

    # Generate rewind indices going backwards from the earliest obs frame
    # rewind_indices[0] is closest to the obs window, rewind_indices[-1] is furthest back
    rewind_indices = []
    for i in range(1, rewind_step + 1):
        idx = earliest_obs_frame - i * frame_gap
        idx = max(ep_start, idx)  # Clamp to episode start
        rewind_indices.append(idx)

    return rewind_step, rewind_indices

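# A worked example (hypothetical indices): with frame_idx=1000, ep_start=800,
# n_obs_steps=8 and frame_gap=30, the earliest observation frame is 1000 - 4 * 30 = 880,
# leaving 80 frames of history, so at most 80 // 30 = 2 rewind steps are valid.
#
# >>> apply_rewind_augmentation(1000, 800, n_obs_steps=8, max_rewind_steps=4, rewind_step=2)
# (2, [850, 820])

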
def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float:
    """Compute τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]. Returns 1.0 for zero-duration subtasks."""
    duration = subtask_end - subtask_start
    if duration <= 0:
        return 1.0
    return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))

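# For instance (hypothetical frames):
#
# >>> compute_tau(70, 40, 99)  # (70 - 40) / (99 - 40)
# 0.5084...
# >>> compute_tau(5, 10, 10)   # zero-duration subtask
# 1.0

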
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
    """Pad the state tensor's last dimension to max_state_dim with zeros."""
    current_dim = state.shape[-1]
    if current_dim >= max_state_dim:
        return state[..., :max_state_dim]  # Truncate if larger

    # Pad with zeros on the right
    padding = (0, max_state_dim - current_dim)  # (left, right) for last dim
    return F.pad(state, padding, mode="constant", value=0)

def temporal_proportions_to_breakpoints(
    temporal_proportions: dict[str, float] | list[float] | None,
    subtask_names: list[str] | None = None,
) -> list[float] | None:
    """Convert temporal proportions to cumulative breakpoints for normalization."""
    if temporal_proportions is None:
        return None

    if isinstance(temporal_proportions, dict):
        if subtask_names is not None:
            proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names]
        else:
            proportions = list(temporal_proportions.values())
    else:
        proportions = list(temporal_proportions)

    total = sum(proportions)
    if total > 0 and abs(total - 1.0) > 1e-6:
        proportions = [p / total for p in proportions]

    breakpoints = [0.0]
    cumsum = 0.0
    for prop in proportions:
        cumsum += prop
        breakpoints.append(cumsum)
    breakpoints[-1] = 1.0

    return breakpoints

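# A worked example: proportions are renormalized if needed, then accumulated into
# cumulative breakpoints with the last value pinned to 1.0.
#
# >>> temporal_proportions_to_breakpoints({"grasp": 0.4, "place": 0.6}, ["grasp", "place"])
# [0.0, 0.4, 1.0]

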
def normalize_stage_tau(
    x: float | torch.Tensor,
    num_stages: int | None = None,
    breakpoints: list[float] | None = None,
    temporal_proportions: dict[str, float] | list[float] | None = None,
    subtask_names: list[str] | None = None,
) -> float | torch.Tensor:
    """
    Normalize stage+tau reward to [0, 1] with custom breakpoints.

    Maps stage index + within-stage tau to normalized progress [0, 1].
    The breakpoints are designed to give appropriate weight to each stage
    based on its importance in the task (using temporal proportions).

    Priority: breakpoints > temporal_proportions > linear fallback

    Args:
        x: Raw reward value (stage index + tau) where stage ∈ [0, num_stages-1] and tau ∈ [0, 1)
        num_stages: Number of stages (required if breakpoints/proportions not provided)
        breakpoints: Optional custom breakpoints list of length num_stages + 1.
        temporal_proportions: Optional temporal proportions dict/list to compute breakpoints.
        subtask_names: Optional ordered list of subtask names (for dict proportions)

    Returns:
        Normalized progress value ∈ [0, 1]
    """
    if breakpoints is not None:
        num_stages = len(breakpoints) - 1
    elif temporal_proportions is not None:
        breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
        num_stages = len(breakpoints) - 1
    elif num_stages is not None:
        breakpoints = [i / num_stages for i in range(num_stages + 1)]
    else:
        raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided")

    if isinstance(x, torch.Tensor):
        result = torch.zeros_like(x)
        for i in range(num_stages):
            mask = (x >= i) & (x < i + 1)
            tau_in_stage = x - i
            result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i])
        result[x >= num_stages] = 1.0
        return result.clamp(0.0, 1.0)
    else:
        if x < 0:
            return 0.0
        if x >= num_stages:
            return 1.0
        stage = int(x)
        tau = x - stage
        return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])
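
# Continuing the example above: with breakpoints [0.0, 0.4, 1.0], a raw reward of
# stage 1 + tau 0.5 lands halfway through the second segment.
#
# >>> normalize_stage_tau(1.5, breakpoints=[0.0, 0.4, 1.0])
# 0.7 (up to float rounding): 0.4 + 0.5 * (1.0 - 0.4)
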
@@ -231,6 +231,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: SmolVLAConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -352,8 +353,19 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def _rtc_enabled(self) -> bool:
         return self.config.rtc_config is not None and self.config.rtc_config.enabled
 
-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
-        """Do a full training forward pass to compute the loss"""
+    def forward(
+        self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean"
+    ) -> dict[str, Tensor]:
+        """Do a full training forward pass to compute the loss.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            noise: Optional noise tensor for flow matching.
+            time: Optional time tensor for flow matching.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
         if self.config.adapt_to_pi_aloha:
             batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
             batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
@@ -377,11 +389,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
             losses = losses[:, :, : self.config.max_action_dim]
             loss_dict["losses_after_rm_padding"] = losses.clone()
 
-        # For backward pass
-        loss = losses.mean()
-        # For backward pass
-        loss_dict["loss"] = loss.item()
-        return loss, loss_dict
+        if reduction == "none":
+            # Return per-sample losses (B,) by averaging over time and action dims
+            per_sample_loss = losses.mean(dim=(1, 2))
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = losses.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict
 
     def prepare_images(self, batch):
         """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
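
# A minimal sketch (hypothetical variables) of how a caller can request the
# per-sample losses introduced above instead of the scalar mean:
#
#     per_sample_loss, loss_dict = policy.forward(batch, reduction="none")
#     # per_sample_loss has shape (batch_size,), ready for RA-BC weighting
#     loss = (per_sample_loss * weights).sum() / (weights.sum() + 1e-6)
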
@@ -65,6 +65,7 @@ class TDMPCPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: TDMPCConfig,
+        **kwargs,
     ):
         """
         Args:

@@ -231,11 +231,20 @@ def validate_visual_features_consistency(
     """
     Validates visual feature consistency between a policy config and provided dataset/environment features.
 
+    Validation passes if EITHER:
+    - Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
+    - Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
+
     Args:
         cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
         features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
     """
     expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
     provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
-    if not provided_visuals.issubset(expected_visuals):
+
+    # Accept if either direction is a subset
+    policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
+    dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
+
+    if not (policy_subset_of_dataset or dataset_subset_of_policy):
         raise_feature_mismatch_error(provided_visuals, expected_visuals)

@@ -47,6 +47,7 @@ class VQBeTPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: VQBeTConfig | None = None,
+        **kwargs,
     ):
         """
         Args:

@@ -273,7 +273,7 @@ class XVLAPolicy(PreTrainedPolicy):
     config_class = XVLAConfig
     name = "xvla"
 
-    def __init__(self, config: XVLAConfig):
+    def __init__(self, config: XVLAConfig, **kwargs):
         super().__init__(config)
         config.validate_features()
         florence_config = config.get_florence_config()

@@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
     task_key = {"task": batch["task"]} if "task" in batch else {}
     index_key = {"index": batch["index"]} if "index" in batch else {}
     task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
+    episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
 
-    return {**pad_keys, **task_key, **index_key, **task_index_key}
+    return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
 
 
 def create_transition(

@@ -62,6 +62,7 @@ def update_policy(
     accelerator: Accelerator,
     lr_scheduler=None,
     lock=None,
+    rabc_weights_provider=None,
 ) -> tuple[MetricsTracker, dict]:
     """
     Performs a single training step to update the policy's weights.
@@ -78,6 +79,7 @@ def update_policy(
         accelerator: The Accelerator instance for distributed training and mixed precision.
         lr_scheduler: An optional learning rate scheduler.
         lock: An optional lock for thread-safe optimizer updates.
+        rabc_weights_provider: Optional RABCWeights instance for sample weighting.
 
     Returns:
         A tuple containing:
@@ -87,9 +89,30 @@ def update_policy(
     start_time = time.perf_counter()
     policy.train()
 
+    # Get RA-BC weights if enabled
+    rabc_batch_weights = None
+    rabc_batch_stats = None
+    if rabc_weights_provider is not None:
+        rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
+
     # Let accelerator handle mixed precision
     with accelerator.autocast():
-        loss, output_dict = policy.forward(batch)
+        # Use per-sample loss when RA-BC is enabled for proper weighting
+        if rabc_batch_weights is not None:
+            # Get per-sample losses
+            per_sample_loss, output_dict = policy.forward(batch, reduction="none")
+
+            # Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σ w_i + ε)
+            # rabc_batch_weights is already normalized to sum to batch_size
+            epsilon = 1e-6
+            loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
+            # Log the raw mean weight (before normalization) - this is the meaningful metric
+            output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
+            output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
+            output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
+        else:
+            loss, output_dict = policy.forward(batch)
 
     # TODO(rcadene): policy.unnormalize_outputs(out_dict)
 
     # Use accelerator's backward method
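
# A worked numeric example (hypothetical values) of the weighted loss above:
# per-sample losses l = [1.0, 2.0, 4.0] with RA-BC weights w = [1.0, 0.5, 0.0]
# give L = (1.0 * 1.0 + 0.5 * 2.0 + 0.0 * 4.0) / (1.5 + 1e-6) ≈ 1.333, so the
# zero-weight (negative-progress) sample contributes nothing to the gradient.
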
@@ -141,8 +164,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         cfg: A `TrainPipelineConfig` object containing all training configurations.
         accelerator: Optional Accelerator instance. If None, one will be created automatically.
     """
-    cfg.validate()
-
     # Create Accelerator if not provided
     # It will automatically detect if running in distributed mode or single-process mode
     # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -159,6 +180,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     # When using accelerate, only the main process should log to avoid duplicate outputs
     is_main_process = accelerator.is_main_process
 
+    cfg.validate()
+
     # Only log on main process
     if is_main_process:
         logging.info(pformat(cfg.to_dict()))
@@ -217,6 +240,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         # Only provide dataset_stats when not resuming from saved processor state
         processor_kwargs["dataset_stats"] = dataset.meta.stats
 
+    # For SARM, always provide dataset_meta for progress normalization
+    if cfg.policy.type == "sarm":
+        processor_kwargs["dataset_meta"] = dataset.meta
+
     if cfg.policy.pretrained_path is not None:
         processor_kwargs["preprocessor_overrides"] = {
             "device_processor": {"device": device.type},
@@ -248,6 +275,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
 
+    # Load precomputed SARM progress for RA-BC if enabled
+    # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
+    rabc_weights = None
+    if cfg.use_rabc:
+        from lerobot.utils.rabc import RABCWeights
+
+        # Get chunk_size from policy config
+        chunk_size = getattr(policy.config, "chunk_size", None)
+        if chunk_size is None:
+            raise ValueError("Chunk size is not found in policy config")
+
+        head_mode = getattr(cfg, "rabc_head_mode", "sparse")
+        logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
+        logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
+        rabc_weights = RABCWeights(
+            progress_path=cfg.rabc_progress_path,
+            chunk_size=chunk_size,
+            head_mode=head_mode,
+            kappa=getattr(cfg, "rabc_kappa", 0.01),
+            epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
+            device=device,
+        )
+
     step = 0  # number of policy updates (forward + backward + optim)
 
     if cfg.resume:
@@ -327,7 +377,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         )
 
     if is_main_process:
-        logging.info("Start offline training on a fixed dataset")
+        logging.info(
+            f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
+        )
 
     for _ in range(step, cfg.steps):
         start_time = time.perf_counter()
@@ -343,6 +395,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
             cfg.optimizer.grad_clip_norm,
             accelerator=accelerator,
             lr_scheduler=lr_scheduler,
+            rabc_weights_provider=rabc_weights,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -359,6 +412,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
                 wandb_log_dict = train_tracker.to_dict()
                 if output_dict:
                     wandb_log_dict.update(output_dict)
+                # Log RA-BC statistics if enabled
+                if rabc_weights is not None:
+                    rabc_stats = rabc_weights.get_stats()
+                    wandb_log_dict.update(
+                        {
+                            "rabc_delta_mean": rabc_stats["delta_mean"],
+                            "rabc_delta_std": rabc_stats["delta_std"],
+                            "rabc_num_frames": rabc_stats["num_frames"],
+                        }
+                    )
                 wandb_logger.log_dict(wandb_log_dict, step)
                 train_tracker.reset_averages()

@@ -0,0 +1,276 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path

import numpy as np
import pandas as pd
import torch


class RABCWeights:
    """
    Load precomputed SARM progress values and compute RA-BC weights during training.

    Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
    During training, computes:
    - progress_delta = progress[t + chunk_size] - progress[t]
    - rabc_weight based on the delta (paper Eq. 8-9)

    Args:
        progress_path: Path to parquet file with precomputed progress values
        chunk_size: Number of frames ahead for computing progress delta
        head_mode: Which SARM head to use ("sparse" or "dense")
        kappa: Hard threshold for high-quality samples (default: 0.01)
        epsilon: Small constant for numerical stability (default: 1e-6)
        fallback_weight: Weight to use for frames without a valid delta (default: 1.0)
        device: Device to return tensors on
    """

    def __init__(
        self,
        progress_path: str | Path,
        chunk_size: int = 50,
        head_mode: str = "sparse",
        kappa: float = 0.01,
        epsilon: float = 1e-6,
        fallback_weight: float = 1.0,
        device: torch.device | None = None,
    ):
        self.progress_path = Path(progress_path)
        self.chunk_size = chunk_size
        self.head_mode = head_mode
        self.kappa = kappa
        self.epsilon = epsilon
        self.fallback_weight = fallback_weight
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Determine progress column name
        self.progress_column = f"progress_{head_mode}"

        # Load progress values
        logging.info(f"Loading SARM progress values from {self.progress_path}")
        self.df = pd.read_parquet(self.progress_path)

        # Check if the requested head mode column exists
        if self.progress_column not in self.df.columns:
            available = [c for c in self.df.columns if c.startswith("progress")]
            raise ValueError(
                f"Column '{self.progress_column}' not found. Available progress columns: {available}"
            )

        logging.info(f"Using progress column: {self.progress_column}")

        self.progress_lookup = {}
        self.episode_lookup = {}

        for _, row in self.df.iterrows():
            global_idx = int(row["index"])
            progress = row[self.progress_column]
            episode_idx = int(row["episode_index"])

            if not np.isnan(progress):
                self.progress_lookup[global_idx] = float(progress)
                self.episode_lookup[global_idx] = episode_idx

        # Build episode boundaries for delta computation
        self.episode_boundaries = {}
        for episode_idx in self.df["episode_index"].unique():
            ep_df = self.df[self.df["episode_index"] == episode_idx]
            self.episode_boundaries[int(episode_idx)] = {
                "start": int(ep_df["index"].min()),
                "end": int(ep_df["index"].max()) + 1,
            }

        logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
        logging.info(f"Chunk size for delta computation: {chunk_size}")

        # Compute global statistics for weight computation
        self._compute_global_stats()

    def _compute_global_stats(self):
        """Compute global mean and std of progress deltas for weight calculation."""
        all_deltas = []

        for global_idx, progress in self.progress_lookup.items():
            episode_idx = self.episode_lookup.get(global_idx)
            if episode_idx is None:
                continue

            bounds = self.episode_boundaries.get(episode_idx)
            if bounds is None:
                continue

            future_idx = global_idx + self.chunk_size
            if future_idx >= bounds["end"]:
                # Near end of episode: use the last frame's progress
                future_idx = bounds["end"] - 1

            future_progress = self.progress_lookup.get(future_idx)
            if future_progress is not None:
                delta = future_progress - progress
                all_deltas.append(delta)

        if all_deltas:
            self.delta_mean = max(np.mean(all_deltas), 0.0)
            self.delta_std = max(np.std(all_deltas), self.epsilon)
            logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
        else:
            self.delta_mean = 0.0
            self.delta_std = self.epsilon
            logging.warning("No valid progress deltas found, using default stats")

    def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
        """
        Compute RA-BC weights for a batch.

        For each sample:
        1. Get progress at the current frame
        2. Get progress at frame + chunk_size (within the same episode)
        3. Compute delta = future_progress - current_progress
        4. Compute the weight using paper Eq. 8-9

        Args:
            batch: Training batch containing an "index" key with global frame indices

        Returns:
            Tuple of:
            - Weights tensor (batch_size,) normalized to sum to batch_size
            - Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
        """
        indices = batch.get("index")
        if indices is None:
            logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
            batch_size = self._get_batch_size(batch)
            # Provide the same stats keys as the normal path so callers can log them unconditionally
            uniform_stats = {"raw_mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
            return torch.ones(batch_size, device=self.device), uniform_stats

        # Convert to a list of ints
        if isinstance(indices, torch.Tensor):
            indices = indices.cpu().numpy().tolist()
        elif isinstance(indices, np.ndarray):
            indices = indices.tolist()

        # Compute deltas and weights for each sample
        deltas = []
        for idx in indices:
            idx = int(idx)
            delta = self._compute_delta(idx)
            deltas.append(delta)

        deltas = np.array(deltas, dtype=np.float32)

        # Compute weights from deltas
        weights = self._compute_weights(deltas)

        # Compute stats before normalization for logging
        raw_mean_weight = float(np.nanmean(weights))
        num_zero_weight = int(np.sum(weights == 0))
        num_full_weight = int(np.sum(weights == 1.0))
        batch_stats = {
            "raw_mean_weight": raw_mean_weight,
            "num_zero_weight": num_zero_weight,
            "num_full_weight": num_full_weight,
        }

        weights = torch.tensor(weights, device=self.device, dtype=torch.float32)

        # Normalize to sum to batch_size
        batch_size = len(weights)
        weight_sum = weights.sum() + self.epsilon
        weights = weights * batch_size / weight_sum

        return weights, batch_stats

    def _compute_delta(self, global_idx: int) -> float:
        """Compute the progress delta for a single frame."""
        current_progress = self.progress_lookup.get(global_idx)
        if current_progress is None:
            return np.nan

        episode_idx = self.episode_lookup.get(global_idx)
        if episode_idx is None:
            return np.nan

        bounds = self.episode_boundaries.get(episode_idx)
        if bounds is None:
            return np.nan

        future_idx = global_idx + self.chunk_size  # Δ = chunk_size
        if future_idx >= bounds["end"]:
            # Near end of episode: use the last frame's progress instead
            future_idx = bounds["end"] - 1

        future_progress = self.progress_lookup.get(future_idx)
        if future_progress is None:
            return np.nan

        return future_progress - current_progress

    def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
        """
        Compute RA-BC weights from progress deltas.

        Following paper Eq. 8-9:
        - Soft weight: w̃_i = clip((r_i - (μ - 2σ)) / (4σ + ε), 0, 1)
        - Final weight: w_i = 1{r_i > κ} + 1{0 ≤ r_i ≤ κ} · w̃_i

        Returns:
            Array of weights
        """
        valid_mask = ~np.isnan(deltas)

        # Compute soft weights using global statistics
        lower_bound = self.delta_mean - 2 * self.delta_std
        soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
        soft_weights = np.clip(soft_weights, 0.0, 1.0)

        # Apply the paper's Eq. 9
        weights = np.zeros_like(deltas, dtype=np.float32)

        # High quality: r_i > kappa -> weight = 1
        high_quality_mask = deltas > self.kappa
        weights[high_quality_mask] = 1.0

        # Moderate quality: 0 <= r_i <= kappa -> weight = soft_weight
        moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
        weights[moderate_mask] = soft_weights[moderate_mask]

        # Negative progress: r_i < 0 -> weight = 0 (already 0)
        # Invalid (NaN): use the fallback weight
        weights[~valid_mask] = self.fallback_weight

        return weights

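    # A worked example with hypothetical stats mu = 0.02, sigma = 0.01 and kappa = 0.01:
    #   r = 0.05  -> r > kappa, so w = 1.0
    #   r = 0.005 -> soft weight: clip((0.005 - (0.02 - 2 * 0.01)) / (4 * 0.01 + eps), 0, 1) ≈ 0.125
    #   r = -0.01 -> negative progress, w = 0.0
    #   r = NaN   -> fallback_weight (1.0 by default)
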
    def _get_batch_size(self, batch: dict) -> int:
        """Determine the batch size from the batch."""
        for key in ["action", "index"]:
            if key in batch:
                val = batch[key]
                if isinstance(val, (torch.Tensor, np.ndarray)):
                    return val.shape[0]
        return 1

    def get_stats(self) -> dict:
        """Get statistics."""
        return {
            "num_frames": len(self.progress_lookup),
            "chunk_size": self.chunk_size,
            "head_mode": self.head_mode,
            "delta_mean": self.delta_mean,
            "delta_std": self.delta_std,
            "kappa": self.kappa,
        }
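
# A minimal usage sketch; the parquet path is hypothetical and must come from
# compute_rabc_weights.py as documented above.
#
# >>> provider = RABCWeights("outputs/sarm_progress.parquet", chunk_size=50, head_mode="sparse")
# >>> weights, stats = provider.compute_batch_weights(batch)  # batch must carry "index"
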
@@ -0,0 +1,694 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor.core import TransitionKey
|
||||
|
||||
|
||||
class MockDatasetMeta:
|
||||
"""Mock dataset metadata for testing processor."""
|
||||
|
||||
def __init__(self, episodes: list[dict]):
|
||||
self._episodes = episodes
|
||||
|
||||
@property
|
||||
def episodes(self):
|
||||
"""Return episodes as a mock object with to_pandas() method."""
|
||||
mock = MagicMock()
|
||||
mock.__len__ = lambda s: len(self._episodes)
|
||||
mock.__getitem__ = lambda s, idx: self._episodes[idx]
|
||||
mock.to_pandas = lambda: pd.DataFrame(self._episodes)
|
||||
return mock
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock SARMConfig for testing processor methods."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_obs_steps: int = 8,
|
||||
max_rewind_steps: int = 4,
|
||||
frame_gap: int = 30,
|
||||
sparse_subtask_names: list = None,
|
||||
sparse_temporal_proportions: list = None,
|
||||
dense_subtask_names: list = None,
|
||||
dense_temporal_proportions: list = None,
|
||||
image_key: str = "observation.images.top",
|
||||
state_key: str = "observation.state",
|
||||
max_state_dim: int = 32,
|
||||
device: str = None,
|
||||
rewind_probability: float = 0.8,
|
||||
language_perturbation_probability: float = 0.2,
|
||||
annotation_mode: str = "dual",
|
||||
clip_batch_size: int = 64,
|
||||
text_dim: int = 512,
|
||||
):
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.max_rewind_steps = max_rewind_steps
|
||||
self.frame_gap = frame_gap
|
||||
self.sparse_subtask_names = sparse_subtask_names or ["task"]
|
||||
self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0]
|
||||
self.dense_subtask_names = dense_subtask_names
|
||||
self.dense_temporal_proportions = dense_temporal_proportions
|
||||
self.uses_dual_heads = annotation_mode in ["dense_only", "dual"]
|
||||
self.image_key = image_key
|
||||
self.state_key = state_key
|
||||
self.max_state_dim = max_state_dim
|
||||
self.device = device
|
||||
self.rewind_probability = rewind_probability
|
||||
self.language_perturbation_probability = language_perturbation_probability
|
||||
self.annotation_mode = annotation_mode
|
||||
self.clip_batch_size = clip_batch_size
|
||||
self.text_dim = text_dim
|
||||
|
||||
# Compute observation delta indices (same as config: bidirectional)
|
||||
half_steps = self.n_obs_steps // 2
|
||||
past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
|
||||
obs_deltas = past_deltas + [0] + future_deltas
|
||||
rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
|
||||
self.observation_delta_indices = obs_deltas + rewind_deltas
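# For the defaults above (n_obs_steps=8, frame_gap=30, max_rewind_steps=4) this gives, for illustration:
# obs_deltas = [-120, -90, -60, -30, 0, 30, 60, 90, 120] and rewind_deltas = [-30, -60, -90, -120],
# i.e. 13 delta indices in total, matching num_frames below.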
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
return 1 + self.n_obs_steps + self.max_rewind_steps
|
||||
|
||||
|
||||
class TestSARMEncodingProcessorStepEndToEnd:
|
||||
"""End-to-end test for SARMEncodingProcessorStep with dummy batch data."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clip_model(self):
|
||||
"""Mock CLIP model to avoid loading real weights."""
|
||||
with (
|
||||
patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
|
||||
patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
|
||||
):
|
||||
# Mock the CLIP model - return embeddings based on input batch size
|
||||
mock_model = MagicMock()
|
||||
|
||||
def get_image_features_side_effect(**kwargs):
|
||||
pixel_values = kwargs.get("pixel_values")
|
||||
batch_size = pixel_values.shape[0] if pixel_values is not None else 1
|
||||
return torch.randn(batch_size, 512)
|
||||
|
||||
mock_model.get_image_features.side_effect = get_image_features_side_effect
|
||||
mock_model.get_text_features.return_value = torch.randn(1, 512)
|
||||
mock_model.to.return_value = mock_model
|
||||
mock_model_cls.from_pretrained.return_value = mock_model
|
||||
|
||||
# Mock the CLIP processor - return tensors based on input images
|
||||
mock_processor = MagicMock()
|
||||
|
||||
def processor_side_effect(images=None, **kwargs):
|
||||
num_images = len(images) if images is not None else 1
|
||||
return {
|
||||
"pixel_values": torch.randn(num_images, 3, 224, 224),
|
||||
}
|
||||
|
||||
mock_processor.side_effect = processor_side_effect
|
||||
# Mock tokenizer for text encoding
|
||||
mock_processor.tokenizer.return_value = {
|
||||
"input_ids": torch.ones(1, 77, dtype=torch.long),
|
||||
"attention_mask": torch.ones(1, 77, dtype=torch.long),
|
||||
}
|
||||
mock_processor_cls.from_pretrained.return_value = mock_processor
|
||||
|
||||
yield mock_model, mock_processor
|
||||
|
||||
@pytest.fixture
|
||||
def processor_with_mocks(self, mock_clip_model):
|
||||
"""Create a processor with mocked CLIP and dataset metadata for dual mode."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
# Dual mode config with both sparse and dense annotations
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0, # Disable for deterministic test
|
||||
language_perturbation_probability=0.0, # Disable for deterministic test
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["reach", "grasp", "lift"],
|
||||
sparse_temporal_proportions=[0.3, 0.4, 0.3],
|
||||
dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
# Create mock dataset metadata with one episode of 300 frames
|
||||
# Include annotation columns for dual mode
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 300,
|
||||
"task": "pick up the cube",
|
||||
"sparse_subtask_names": ["reach", "grasp", "lift"],
|
||||
"sparse_subtask_start_frames": [0, 90, 210],
|
||||
"sparse_subtask_end_frames": [90, 210, 300],
|
||||
"dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"],
|
||||
"dense_subtask_start_frames": [0, 75, 150, 225],
|
||||
"dense_subtask_end_frames": [75, 150, 225, 300],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(
|
||||
config=config,
|
||||
dataset_meta=dataset_meta,
|
||||
)
|
||||
processor.train(True) # Use train() method, not direct assignment
|
||||
|
||||
return processor, config
|
||||
|
||||
def test_call_with_single_frame_batch(self, processor_with_mocks):
|
||||
"""Test processor __call__ with a single-frame batch."""
|
||||
processor, config = processor_with_mocks
|
||||
|
||||
# Create dummy input transition
|
||||
batch_size = 1
|
||||
num_frames = config.num_frames # 13 frames (9 obs + 4 rewind)
|
||||
|
||||
# Image: (T, C, H, W) format as expected by processor
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
|
||||
# State: (T, D) format
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 150, # Middle of episode
|
||||
"episode_index": 0,
|
||||
"task": "pick up the cube",
|
||||
},
|
||||
}
|
||||
|
||||
# Run processor
|
||||
result = processor(transition)
|
||||
|
||||
# Verify output structure
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check video features exist and have correct shape
|
||||
assert "video_features" in obs
|
||||
video_features = obs["video_features"]
|
||||
assert video_features.shape[0] == batch_size
|
||||
assert video_features.shape[1] == num_frames
|
||||
assert video_features.shape[2] == 512 # CLIP embedding dim
|
||||
|
||||
# Check state features exist and have correct shape
|
||||
assert "state_features" in obs
|
||||
state_features = obs["state_features"]
|
||||
assert state_features.shape[0] == batch_size
|
||||
assert state_features.shape[1] == num_frames
|
||||
assert state_features.shape[2] == config.max_state_dim # Padded to max_state_dim
|
||||
|
||||
# Check text features exist and have correct shape
|
||||
assert "text_features" in obs
|
||||
text_features = obs["text_features"]
|
||||
assert text_features.shape[0] == batch_size
|
||||
assert text_features.shape[1] == 512 # CLIP embedding dim
|
||||
|
||||
# Check lengths tensor
|
||||
assert "lengths" in obs
|
||||
lengths = obs["lengths"]
|
||||
assert lengths.shape[0] == batch_size
|
||||
assert lengths.dtype == torch.int32
|
||||
|
||||
# Check sparse_targets exist
|
||||
assert "sparse_targets" in obs
|
||||
sparse_targets = obs["sparse_targets"]
|
||||
assert sparse_targets.shape == (batch_size, num_frames)
|
||||
# All targets should be in [0, max_stages] range (stage.tau format)
|
||||
assert (sparse_targets >= 0).all()
|
||||
|
||||
# Check dense_targets exist (for dual mode)
|
||||
assert "dense_targets" in obs
|
||||
dense_targets = obs["dense_targets"]
|
||||
assert dense_targets.shape == (batch_size, num_frames)
|
||||
assert (dense_targets >= 0).all()
|
||||
|
||||
def test_call_with_batched_input(self, mock_clip_model):
|
||||
"""Test processor __call__ with a batched input (multiple frames) in dual mode."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["reach", "grasp"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["step1", "step2", "step3"],
|
||||
dense_temporal_proportions=[0.33, 0.34, 0.33],
|
||||
)
|
||||
|
||||
# Two episodes with different lengths, each with sparse+dense annotations
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 200,
|
||||
"task": "task A",
|
||||
"sparse_subtask_names": ["reach", "grasp"],
|
||||
"sparse_subtask_start_frames": [0, 100],
|
||||
"sparse_subtask_end_frames": [100, 200],
|
||||
"dense_subtask_names": ["step1", "step2", "step3"],
|
||||
"dense_subtask_start_frames": [0, 66, 133],
|
||||
"dense_subtask_end_frames": [66, 133, 200],
|
||||
},
|
||||
{
|
||||
"dataset_from_index": 200,
|
||||
"dataset_to_index": 500,
|
||||
"task": "task B",
|
||||
"sparse_subtask_names": ["reach", "grasp"],
|
||||
"sparse_subtask_start_frames": [200, 350],
|
||||
"sparse_subtask_end_frames": [350, 500],
|
||||
"dense_subtask_names": ["step1", "step2", "step3"],
|
||||
"dense_subtask_start_frames": [200, 300, 400],
|
||||
"dense_subtask_end_frames": [300, 400, 500],
|
||||
},
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
batch_size = 2
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Image: (B, T, C, H, W) format
|
||||
dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": np.array([100, 350]), # One frame from each episode
|
||||
"episode_index": np.array([0, 1]),
|
||||
"task": ["task A", "task B"],
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Verify batch dimension is preserved for all outputs
|
||||
assert obs["video_features"].shape[0] == batch_size
|
||||
assert obs["state_features"].shape[0] == batch_size
|
||||
assert obs["lengths"].shape[0] == batch_size
|
||||
assert obs["sparse_targets"].shape[0] == batch_size
|
||||
assert obs["dense_targets"].shape[0] == batch_size # Dual mode has dense targets
|
||||
|
||||
def test_targets_increase_with_progress(self, mock_clip_model):
|
||||
"""Test that both sparse and dense targets increase as frame index progresses."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=30,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["phase1", "phase2"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["a", "b", "c", "d"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 300,
|
||||
"task": "test task",
|
||||
"sparse_subtask_names": ["phase1", "phase2"],
|
||||
"sparse_subtask_start_frames": [0, 150],
|
||||
"sparse_subtask_end_frames": [150, 300],
|
||||
"dense_subtask_names": ["a", "b", "c", "d"],
|
||||
"dense_subtask_start_frames": [0, 75, 150, 225],
|
||||
"dense_subtask_end_frames": [75, 150, 225, 300],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at early, middle, and late points in episode
|
||||
frame_indices = [30, 150, 270]
|
||||
sparse_center_targets = []
|
||||
dense_center_targets = []
|
||||
|
||||
for frame_idx in frame_indices:
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": frame_idx,
|
||||
"episode_index": 0,
|
||||
"task": "test task",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
# Get target at center frame (index 4 in 9-frame observation window)
|
||||
sparse_center_targets.append(obs["sparse_targets"][0, 4].item())
|
||||
dense_center_targets.append(obs["dense_targets"][0, 4].item())
|
||||
|
||||
# Both sparse and dense targets should increase with frame index
|
||||
assert sparse_center_targets[0] < sparse_center_targets[2], (
|
||||
f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})"
|
||||
)
|
||||
assert dense_center_targets[0] < dense_center_targets[2], (
|
||||
f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})"
|
||||
)
|
||||
|
||||
def test_progress_labels_exact_values(self, mock_clip_model):
|
||||
"""Test that progress labels (stage.tau) are computed correctly for known positions."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
# Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10, # Smaller gap for easier calculation
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["A", "B"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["d1", "d2", "d3", "d4"],
|
||||
dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
|
||||
)
|
||||
|
||||
# Episode: frames 0-99, sparse stages at [0-49], [50-99]
|
||||
# Dense stages at [0-24], [25-49], [50-74], [75-99]
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 100,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["A", "B"],
|
||||
"sparse_subtask_start_frames": [0, 50],
|
||||
"sparse_subtask_end_frames": [50, 100],
|
||||
"dense_subtask_names": ["d1", "d2", "d3", "d4"],
|
||||
"dense_subtask_start_frames": [0, 25, 50, 75],
|
||||
"dense_subtask_end_frames": [25, 50, 75, 100],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at frame 50 (center of episode)
|
||||
# With frame_gap=10, n_obs_steps=8:
|
||||
# obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 50,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
sparse_targets = obs["sparse_targets"][0] # (13,)
|
||||
dense_targets = obs["dense_targets"][0] # (13,)
|
||||
|
||||
# First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind)
|
||||
# Check that obs frames have non-zero targets
|
||||
obs_sparse = sparse_targets[:9]
|
||||
obs_dense = dense_targets[:9]
|
||||
|
||||
# Verify targets are monotonically increasing for observation frames
|
||||
for i in range(1, 9):
|
||||
assert obs_sparse[i] >= obs_sparse[i - 1], (
|
||||
f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}"
|
||||
)
|
||||
assert obs_dense[i] >= obs_dense[i - 1], (
|
||||
f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}"
|
||||
)
|
||||
|
||||
# Rewind slots should be zero when rewind is disabled
|
||||
rewind_targets = sparse_targets[9:]
|
||||
assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled"
|
||||
|
||||
# Check stage transitions: frame 50 is at boundary of sparse stage A->B
|
||||
# Center frame (index 4) corresponds to actual frame 50
|
||||
center_sparse = obs_sparse[4].item()
|
||||
# At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0)
|
||||
assert 0.9 <= center_sparse <= 1.1, (
|
||||
f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}"
|
||||
)
|
||||
|
||||
def test_rewind_augmentation_applied(self, mock_clip_model):
|
||||
"""Test that rewind augmentation correctly extends sequence and generates targets."""
|
||||
import random
|
||||
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10,
|
||||
rewind_probability=1.0, # Always apply rewind
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["A", "B"],
|
||||
sparse_temporal_proportions=[0.5, 0.5],
|
||||
dense_subtask_names=["d1", "d2"],
|
||||
dense_temporal_proportions=[0.5, 0.5],
|
||||
)
|
||||
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 200,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["A", "B"],
|
||||
"sparse_subtask_start_frames": [0, 100],
|
||||
"sparse_subtask_end_frames": [100, 200],
|
||||
"dense_subtask_names": ["d1", "d2"],
|
||||
"dense_subtask_start_frames": [0, 100],
|
||||
"dense_subtask_end_frames": [100, 200],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames # 13
|
||||
|
||||
# Test at frame 150 (center of bidirectional window)
|
||||
# With n_obs_steps=8, half_steps=4, frame_gap=10:
|
||||
# - Earliest obs frame = 150 - 4*10 = 110
|
||||
# - Rewind can go back from 110 to frames like 100, 90, 80, 70
|
||||
# - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 150,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
# Seed random for reproducibility
|
||||
random.seed(42)
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
lengths = obs["lengths"][0].item()
|
||||
sparse_targets = obs["sparse_targets"][0]
|
||||
|
||||
# With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind)
|
||||
assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}"
|
||||
assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}"
|
||||
|
||||
# Rewind targets should be non-zero for frames within valid length
|
||||
n_obs_frames = 9
|
||||
rewind_count = lengths - n_obs_frames
|
||||
|
||||
if rewind_count > 0:
|
||||
# Check that rewind frames have targets
|
||||
rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count]
|
||||
# Rewind frames are from BEFORE the earliest obs frame (110)
|
||||
# These frames (100, 90, 80, 70) are earlier in the episode
|
||||
earliest_obs_target = sparse_targets[0].item() # Frame 110
|
||||
|
||||
# Rewind targets should be less than earliest obs (they're from earlier frames)
|
||||
for i, rt in enumerate(rewind_targets):
|
||||
assert rt.item() < earliest_obs_target, (
|
||||
f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})"
|
||||
)
|
||||
|
||||
# Rewind targets should be decreasing (going further back in time)
|
||||
for i in range(1, len(rewind_targets)):
|
||||
assert rewind_targets[i] <= rewind_targets[i - 1], (
|
||||
f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}"
|
||||
)
|
||||
|
||||
def test_full_sequence_target_consistency(self, mock_clip_model):
|
||||
"""Test that the full sequence of targets is consistent with frame positions."""
|
||||
from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
|
||||
from lerobot.policies.sarm.sarm_utils import find_stage_and_tau
|
||||
|
||||
config = MockConfig(
|
||||
n_obs_steps=8,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=10,
|
||||
rewind_probability=0.0,
|
||||
language_perturbation_probability=0.0,
|
||||
annotation_mode="dual",
|
||||
sparse_subtask_names=["s1", "s2", "s3"],
|
||||
sparse_temporal_proportions=[0.33, 0.34, 0.33],
|
||||
dense_subtask_names=["d1", "d2"],
|
||||
dense_temporal_proportions=[0.5, 0.5],
|
||||
)
|
||||
|
||||
# 3 sparse stages: [0-33), [33-66), [66-99]
|
||||
# 2 dense stages: [0-50), [50-100)
|
||||
episodes = [
|
||||
{
|
||||
"dataset_from_index": 0,
|
||||
"dataset_to_index": 100,
|
||||
"task": "test",
|
||||
"sparse_subtask_names": ["s1", "s2", "s3"],
|
||||
"sparse_subtask_start_frames": [0, 33, 66],
|
||||
"sparse_subtask_end_frames": [33, 66, 100],
|
||||
"dense_subtask_names": ["d1", "d2"],
|
||||
"dense_subtask_start_frames": [0, 50],
|
||||
"dense_subtask_end_frames": [50, 100],
|
||||
}
|
||||
]
|
||||
dataset_meta = MockDatasetMeta(episodes)
|
||||
|
||||
processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
|
||||
processor.train(True)
|
||||
|
||||
num_frames = config.num_frames
|
||||
|
||||
# Test at frame 50 (middle of episode)
|
||||
dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
|
||||
dummy_state = np.random.rand(num_frames, 6).astype(np.float32)
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
config.image_key: dummy_image,
|
||||
config.state_key: dummy_state,
|
||||
},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"index": 50,
|
||||
"episode_index": 0,
|
||||
"task": "test",
|
||||
},
|
||||
}
|
||||
|
||||
result = processor(transition)
|
||||
obs = result[TransitionKey.OBSERVATION]
|
||||
sparse_targets = obs["sparse_targets"][0]
|
||||
dense_targets = obs["dense_targets"][0]
|
||||
|
||||
# Manually compute expected targets for observation frames
|
||||
# With frame_gap=10, n_obs_steps=8, center at 50:
|
||||
# obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90]
|
||||
|
||||
sparse_names = ["s1", "s2", "s3"]
|
||||
sparse_starts = [0, 33, 66]
|
||||
sparse_ends = [33, 66, 100]
|
||||
sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33}
|
||||
|
||||
dense_names = ["d1", "d2"]
|
||||
dense_starts = [0, 50]
|
||||
dense_ends = [50, 100]
|
||||
dense_props = {"d1": 0.5, "d2": 0.5}
|
||||
|
||||
for i, frame in enumerate(expected_obs_frames):
|
||||
expected_sparse = find_stage_and_tau(
|
||||
frame,
|
||||
100,
|
||||
sparse_names,
|
||||
sparse_starts,
|
||||
sparse_ends,
|
||||
sparse_names,
|
||||
sparse_props,
|
||||
return_combined=True,
|
||||
)
|
||||
expected_dense = find_stage_and_tau(
|
||||
frame,
|
||||
100,
|
||||
dense_names,
|
||||
dense_starts,
|
||||
dense_ends,
|
||||
dense_names,
|
||||
dense_props,
|
||||
return_combined=True,
|
||||
)
|
||||
|
||||
actual_sparse = sparse_targets[i].item()
|
||||
actual_dense = dense_targets[i].item()
|
||||
|
||||
assert abs(actual_sparse - expected_sparse) < 0.01, (
|
||||
f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}"
|
||||
)
|
||||
assert abs(actual_dense - expected_dense) < 0.01, (
|
||||
f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}"
|
||||
)
|
||||
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.data_processing.sarm_annotations.subtask_annotation import (
|
||||
Subtask,
|
||||
SubtaskAnnotation,
|
||||
Timestamp,
|
||||
compute_temporal_proportions,
|
||||
)
|
||||
|
||||
|
||||
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
|
||||
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
|
||||
return SubtaskAnnotation(
|
||||
subtasks=[
|
||||
Subtask(
|
||||
name=name,
|
||||
timestamps=Timestamp(
|
||||
start=f"{start // 60:02d}:{start % 60:02d}", end=f"{end // 60:02d}:{end % 60:02d}"
|
||||
),
|
||||
)
|
||||
for name, start, end in subtasks
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestComputeTemporalProportions:
|
||||
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
|
||||
|
||||
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
Key insight: This averages the PROPORTION of each subtask within each trajectory,
|
||||
giving equal weight to all trajectories regardless of absolute length.
|
||||
"""
|
||||
|
||||
def test_basic_two_trajectories_equal_proportions(self):
|
||||
"""Test with two trajectories that have equal proportions."""
|
||||
# Both trajectories: subtask1 = 50%, subtask2 = 50%
|
||||
# Traj 1: T=100s, subtask1=50s, subtask2=50s
|
||||
# Traj 2: T=200s, subtask1=100s, subtask2=100s
|
||||
annotations = {
|
||||
0: make_annotation([("subtask1", 0, 50), ("subtask2", 50, 100)]),
|
||||
1: make_annotation([("subtask1", 0, 100), ("subtask2", 100, 200)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Both should be 0.5
|
||||
assert abs(result["subtask1"] - 0.5) < 1e-6
|
||||
assert abs(result["subtask2"] - 0.5) < 1e-6
|
||||
|
||||
def test_paper_example_different_from_avg_durations(self):
|
||||
"""Test that compute_temporal_proportions differs from naive average duration approach.
|
||||
|
||||
This is the key test showing the difference between:
|
||||
- Paper formula: average of (L_i,k / T_i)
|
||||
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
|
||||
"""
|
||||
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
|
||||
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
|
||||
annotations = {
|
||||
0: make_annotation([("subtask1", 0, 80), ("subtask2", 80, 100)]),
|
||||
1: make_annotation([("subtask1", 0, 40), ("subtask2", 40, 200)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Paper formula:
|
||||
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
|
||||
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
|
||||
assert abs(result["subtask1"] - 0.5) < 1e-6
|
||||
assert abs(result["subtask2"] - 0.5) < 1e-6
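# For contrast (illustration only, not asserted here): a naive average of absolute durations would give
# subtask1 = (80 + 40) / 2 = 60s and subtask2 = (20 + 160) / 2 = 90s, i.e. proportions
# 60 / 150 = 0.4 and 90 / 150 = 0.6, which differ from the paper formula's 0.5 / 0.5.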
|
||||
|
||||
def test_single_trajectory(self):
|
||||
"""Test with a single trajectory."""
|
||||
# T=100s, reach=30s, grasp=20s, lift=50s
|
||||
annotations = {
|
||||
0: make_annotation([("reach", 0, 30), ("grasp", 30, 50), ("lift", 50, 100)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
assert abs(result["reach"] - 0.3) < 1e-6
|
||||
assert abs(result["grasp"] - 0.2) < 1e-6
|
||||
assert abs(result["lift"] - 0.5) < 1e-6
|
||||
|
||||
def test_sum_to_one(self):
|
||||
"""Test that proportions always sum to 1."""
|
||||
# Three episodes with varying proportions
|
||||
annotations = {
|
||||
0: make_annotation([("a", 0, 10), ("b", 10, 50), ("c", 50, 100)]), # 0.1, 0.4, 0.5
|
||||
1: make_annotation([("a", 0, 20), ("b", 20, 70), ("c", 70, 100)]), # 0.2, 0.5, 0.3
|
||||
2: make_annotation([("a", 0, 30), ("b", 30, 90), ("c", 90, 100)]), # 0.3, 0.6, 0.1
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
total = sum(result.values())
|
||||
assert abs(total - 1.0) < 1e-6
|
||||
|
||||
def test_empty_annotations_returns_empty(self):
|
||||
"""Test that empty annotations returns empty dict."""
|
||||
result = compute_temporal_proportions({})
|
||||
assert result == {}
|
||||
|
||||
def test_uniform_proportions(self):
|
||||
"""Test with uniform proportions across subtasks."""
|
||||
# Each subtask takes 25% of each episode
|
||||
annotations = {
|
||||
0: make_annotation([("a", 0, 25), ("b", 25, 50), ("c", 50, 75), ("d", 75, 100)]),
|
||||
1: make_annotation([("a", 0, 50), ("b", 50, 100), ("c", 100, 150), ("d", 150, 200)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
for name in ["a", "b", "c", "d"]:
|
||||
assert abs(result[name] - 0.25) < 1e-6
|
||||
@@ -0,0 +1,615 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
compute_tau,
|
||||
find_stage_and_tau,
|
||||
normalize_stage_tau,
|
||||
temporal_proportions_to_breakpoints,
|
||||
)
|
||||
|
||||
|
||||
class TestProgressLabelsWithModes:
|
||||
"""End-to-end tests for progress label generation in different modes."""
|
||||
|
||||
def test_sparse_mode_single_stage(self):
|
||||
"""Sparse mode with single stage should give linear progress."""
|
||||
episode_length = 300
|
||||
global_names = ["task"]
|
||||
proportions = {"task": 1.0}
|
||||
|
||||
# Test at various frames
|
||||
for frame in [0, 100, 200, 299]:
|
||||
stage, tau = find_stage_and_tau(
|
||||
frame, episode_length, None, None, None, global_names, proportions
|
||||
)
|
||||
|
||||
expected_tau = frame / (episode_length - 1)
|
||||
assert stage == 0
|
||||
assert abs(tau - expected_tau) < 1e-5
|
||||
|
||||
def test_sparse_mode_multi_stage(self):
|
||||
"""Sparse mode with multiple stages."""
|
||||
global_names = ["reach", "grasp", "lift", "place"]
|
||||
proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3}
|
||||
|
||||
subtask_names = ["reach", "grasp", "lift", "place"]
|
||||
subtask_starts = [0, 60, 120, 210]
|
||||
subtask_ends = [59, 119, 209, 299]
|
||||
|
||||
# Check stages are correctly identified
|
||||
stage_at_30, _ = find_stage_and_tau(
|
||||
30, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage_at_30 == 0
|
||||
|
||||
stage_at_90, _ = find_stage_and_tau(
|
||||
90, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage_at_90 == 1
|
||||
|
||||
stage_at_150, _ = find_stage_and_tau(
|
||||
150, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage_at_150 == 2
|
||||
|
||||
def test_dense_mode_more_stages(self):
|
||||
"""Dense mode should work with more fine-grained stages."""
|
||||
global_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
|
||||
proportions = dict.fromkeys(global_names, 1 / 8)
|
||||
|
||||
subtask_names = global_names
|
||||
subtask_starts = [i * 50 for i in range(8)]
|
||||
subtask_ends = [(i + 1) * 50 - 1 for i in range(8)]
|
||||
|
||||
# Each stage should occupy 50 frames
|
||||
for stage_idx in range(8):
|
||||
mid_frame = stage_idx * 50 + 25
|
||||
stage, _ = find_stage_and_tau(
|
||||
mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == stage_idx
|
||||
|
||||
|
||||
class TestComputeAbsoluteIndices:
|
||||
"""Tests for compute_absolute_indices (bidirectional sampling)."""
|
||||
|
||||
def test_no_clamping_when_in_middle(self):
|
||||
"""When frame is in middle of episode, no clamping should occur."""
|
||||
frame_idx = 300
|
||||
ep_start = 0
|
||||
ep_end = 1000
|
||||
n_obs_steps = 8
|
||||
frame_gap = 30
|
||||
|
||||
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
|
||||
|
||||
# All should be valid (no out of bounds)
|
||||
assert out_of_bounds.sum() == 0
|
||||
|
||||
# Check bidirectional indices: [-120, -90, -60, -30, 0, 30, 60, 90, 120] from center
|
||||
half_steps = n_obs_steps // 2
|
||||
expected = (
|
||||
[frame_idx - frame_gap * i for i in range(half_steps, 0, -1)]
|
||||
+ [frame_idx]
|
||||
+ [frame_idx + frame_gap * i for i in range(1, half_steps + 1)]
|
||||
)
|
||||
assert indices.tolist() == expected
|
||||
|
||||
# Center frame (index 4) should be frame_idx
|
||||
assert indices[half_steps] == frame_idx
|
||||
|
||||
def test_clamping_at_episode_start(self):
|
||||
"""Early frames should be clamped to episode start."""
|
||||
frame_idx = 50 # Not enough history for full past window
|
||||
ep_start = 0
|
||||
ep_end = 1000
|
||||
n_obs_steps = 8
|
||||
frame_gap = 30
|
||||
|
||||
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
|
||||
|
||||
# Some past frames should be clamped (out_of_bounds = 1)
|
||||
assert out_of_bounds.sum() > 0
|
||||
|
||||
# All indices should be >= ep_start
|
||||
assert (indices >= ep_start).all()
|
||||
|
||||
# Center index should be frame_idx
|
||||
half_steps = n_obs_steps // 2
|
||||
assert indices[half_steps] == frame_idx
|
||||
|
||||
def test_clamping_at_episode_end(self):
|
||||
"""Late frames should be clamped to episode end."""
|
||||
frame_idx = 950 # Not enough future for full window
|
||||
ep_start = 0
|
||||
ep_end = 1000
|
||||
n_obs_steps = 8
|
||||
frame_gap = 30
|
||||
|
||||
indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)
|
||||
|
||||
# Some future frames should be clamped
|
||||
assert out_of_bounds.sum() > 0
|
||||
|
||||
# All indices should be < ep_end
|
||||
assert (indices < ep_end).all()
|
||||
|
||||
# Center index should be frame_idx
|
||||
half_steps = n_obs_steps // 2
|
||||
assert indices[half_steps] == frame_idx
|
||||
|
||||
def test_sequence_is_monotonic(self):
|
||||
"""Frame indices should be monotonically increasing."""
|
||||
for frame_idx in [50, 100, 300, 950]:
|
||||
indices, _ = compute_absolute_indices(frame_idx, 0, 1000, 8, 30)
|
||||
|
||||
# Check monotonic (non-decreasing due to clamping)
|
||||
diffs = indices[1:] - indices[:-1]
|
||||
assert (diffs >= 0).all(), f"Non-monotonic at frame {frame_idx}"
|
||||
|
||||
|
||||
class TestComputeTau:
|
||||
"""Tests for compute_tau (within-subtask progress).
|
||||
|
||||
Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
|
||||
"""
|
||||
|
||||
def test_at_start(self):
|
||||
"""τ should be 0 at subtask start."""
|
||||
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
|
||||
assert tau == 0.0
|
||||
|
||||
def test_at_end(self):
|
||||
"""τ should be 1 at subtask end."""
|
||||
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_at_middle(self):
|
||||
"""τ should be 0.5 at subtask midpoint."""
|
||||
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
|
||||
assert abs(tau - 0.5) < 1e-6
|
||||
|
||||
def test_quarter_progress(self):
|
||||
"""Test τ at 25% through subtask."""
|
||||
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
|
||||
assert abs(tau - 0.25) < 1e-6
|
||||
|
||||
def test_zero_duration_subtask(self):
|
||||
"""τ should be 1.0 for zero-duration subtask."""
|
||||
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_clamps_below_zero(self):
|
||||
"""τ should be clamped to 0 if frame is before subtask."""
|
||||
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
|
||||
assert tau == 0.0
|
||||
|
||||
def test_clamps_above_one(self):
|
||||
"""τ should be clamped to 1 if frame is after subtask."""
|
||||
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_float_inputs(self):
|
||||
"""Test with float frame indices (from interpolation)."""
|
||||
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
|
||||
expected = (25.5 - 10.0) / (50.0 - 10.0)
|
||||
assert abs(tau - expected) < 1e-6
|
||||
|
||||
|
||||
class TestFindStageAndTau:
|
||||
"""Tests for find_stage_and_tau logic.
|
||||
|
||||
This function is the core of progress label computation. It determines
|
||||
which stage a frame belongs to and the within-stage progress (tau).
|
||||
"""
|
||||
|
||||
def test_single_stage_mode_linear_progress(self):
|
||||
"""Single-stage mode should give linear progress from 0 to 1."""
|
||||
episode_length = 100
|
||||
|
||||
# Frame 0 -> tau = 0
|
||||
stage, tau = find_stage_and_tau(0, episode_length, None, None, None, ["task"], {"task": 1.0})
|
||||
assert stage == 0
|
||||
assert abs(tau - 0.0) < 1e-6
|
||||
|
||||
# Frame 50 -> tau = 0.505 (50/99)
|
||||
stage, tau = find_stage_and_tau(50, episode_length, None, None, None, ["task"], {"task": 1.0})
|
||||
assert stage == 0
|
||||
assert abs(tau - 50 / 99) < 1e-6
|
||||
|
||||
# Frame 99 -> tau = 1.0
|
||||
stage, tau = find_stage_and_tau(99, episode_length, None, None, None, ["task"], {"task": 1.0})
|
||||
assert stage == 0
|
||||
assert abs(tau - 1.0) < 1e-6
|
||||
|
||||
def test_multi_stage_within_subtask(self):
|
||||
"""Test finding stage when frame is within a subtask."""
|
||||
global_names = ["reach", "grasp", "lift"]
|
||||
proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5}
|
||||
|
||||
subtask_names = ["reach", "grasp", "lift"]
|
||||
subtask_starts = [0, 30, 50]
|
||||
subtask_ends = [29, 49, 99]
|
||||
|
||||
# Frame 15 in "reach" stage (index 0)
|
||||
stage, tau = find_stage_and_tau(
|
||||
15, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 0
|
||||
assert abs(tau - 15 / 29) < 1e-6
|
||||
|
||||
# Frame 40 in "grasp" stage (index 1)
|
||||
stage, tau = find_stage_and_tau(
|
||||
40, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 1
|
||||
# tau = (40 - 30) / (49 - 30) = 10/19
|
||||
assert abs(tau - 10 / 19) < 1e-6
|
||||
|
||||
# Frame 75 in "lift" stage (index 2)
|
||||
stage, tau = find_stage_and_tau(
|
||||
75, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 2
|
||||
# tau = (75 - 50) / (99 - 50) = 25/49
|
||||
assert abs(tau - 25 / 49) < 1e-6
|
||||
|
||||
def test_frame_at_subtask_boundaries(self):
|
||||
"""Test frames exactly at subtask boundaries."""
|
||||
global_names = ["a", "b"]
|
||||
proportions = {"a": 0.5, "b": 0.5}
|
||||
|
||||
subtask_names = ["a", "b"]
|
||||
subtask_starts = [0, 50]
|
||||
subtask_ends = [49, 99]
|
||||
|
||||
# Frame at start of first subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
0, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 0
|
||||
assert tau == 0.0
|
||||
|
||||
# Frame at end of first subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
49, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 0
|
||||
assert tau == 1.0
|
||||
|
||||
# Frame at start of second subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
50, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 1
|
||||
assert tau == 0.0
|
||||
|
||||
def test_frame_after_last_subtask(self):
|
||||
"""Frames after last subtask should return last stage with high tau."""
|
||||
global_names = ["a", "b"]
|
||||
proportions = {"a": 0.5, "b": 0.5}
|
||||
|
||||
subtask_names = ["a", "b"]
|
||||
subtask_starts = [0, 30]
|
||||
subtask_ends = [29, 59]
|
||||
|
||||
# Frame 80 is after last subtask
|
||||
stage, tau = find_stage_and_tau(
|
||||
80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
|
||||
)
|
||||
assert stage == 1 # Last stage
|
||||
assert tau == 0.999 # Nearly complete
|
||||
|
||||
|
||||
class TestEndToEndProgressLabeling:
|
||||
"""End-to-end tests for progress label computation using normalize_stage_tau."""
|
||||
|
||||
def test_consistent_semantic_meaning(self):
|
||||
"""Test that same subtask completion maps to same progress across trajectories.
|
||||
|
||||
This is the key semantic property: "end of subtask 1" should always
|
||||
mean the same progress value regardless of trajectory speed.
|
||||
"""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
|
||||
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
|
||||
tau_fast = compute_tau(30, 0, 30) # = 1.0
|
||||
y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions)
|
||||
|
||||
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
|
||||
tau_slow = compute_tau(90, 0, 90) # = 1.0
|
||||
y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions)
|
||||
|
||||
# Both should map to same progress (0.3 = end of subtask 1)
|
||||
assert abs(y_fast - y_slow) < 1e-6
|
||||
assert abs(y_fast - 0.3) < 1e-6
|
||||
|
||||
def test_monotonic_within_subtask(self):
|
||||
"""Test that progress is monotonically increasing within a subtask."""
|
||||
proportions = [0.4, 0.6]
|
||||
|
||||
prev_y = -1
|
||||
for tau in np.linspace(0, 1, 11):
|
||||
y = normalize_stage_tau(0 + tau, temporal_proportions=proportions)
|
||||
assert y > prev_y or (tau == 0 and y == 0)
|
||||
prev_y = y
|
||||
|
||||
def test_continuous_across_subtasks(self):
|
||||
"""Test that progress is continuous at subtask boundaries."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
|
||||
# End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0
|
||||
y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions)
|
||||
|
||||
# Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0
|
||||
y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions)
|
||||
|
||||
# Should be equal (P_1 = 0.3)
|
||||
assert abs(y_end_0 - y_start_1) < 1e-6
|
||||
|
||||
# End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0
|
||||
y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions)
|
||||
|
||||
# Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0
|
||||
y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions)
|
||||
|
||||
# Should be equal (P_2 = 0.8)
|
||||
assert abs(y_end_1 - y_start_2) < 1e-6
|
||||
|
||||
|
||||
class TestTemporalProportionsToBreakpoints:
|
||||
"""Tests for temporal_proportions_to_breakpoints.
|
||||
|
||||
Converts temporal proportions to cumulative breakpoints for normalization.
|
||||
Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]
|
||||
"""
|
||||
|
||||
def test_basic_conversion(self):
|
||||
"""Test basic conversion from proportions to breakpoints."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
assert breakpoints is not None
|
||||
assert len(breakpoints) == 4
|
||||
assert breakpoints[0] == 0.0
|
||||
assert abs(breakpoints[1] - 0.3) < 1e-6
|
||||
assert abs(breakpoints[2] - 0.8) < 1e-6
|
||||
assert breakpoints[3] == 1.0
|
||||
|
||||
def test_dict_input(self):
|
||||
"""Test with dict input."""
|
||||
proportions = {"a": 0.25, "b": 0.25, "c": 0.5}
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
assert breakpoints is not None
|
||||
assert len(breakpoints) == 4
|
||||
assert breakpoints[0] == 0.0
|
||||
assert breakpoints[-1] == 1.0
|
||||
|
||||
def test_dict_with_subtask_names_order(self):
|
||||
"""Test that subtask_names determines order for dict input."""
|
||||
proportions = {"c": 0.5, "a": 0.2, "b": 0.3} # Dict order
|
||||
subtask_names = ["a", "b", "c"] # Different order
|
||||
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names)
|
||||
|
||||
# Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5
|
||||
assert abs(breakpoints[1] - 0.2) < 1e-6 # a
|
||||
assert abs(breakpoints[2] - 0.5) < 1e-6 # a + b = 0.5
|
||||
assert breakpoints[3] == 1.0 # a + b + c = 1.0
|
||||
|
||||
def test_uniform_proportions(self):
|
||||
"""Test with uniform proportions."""
|
||||
proportions = [0.25, 0.25, 0.25, 0.25]
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
expected = [0.0, 0.25, 0.5, 0.75, 1.0]
|
||||
for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)):
|
||||
assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch"
|
||||
|
||||
def test_none_input(self):
|
||||
"""Test that None input returns None."""
|
||||
result = temporal_proportions_to_breakpoints(None)
|
||||
assert result is None
|
||||
|
||||
def test_normalization(self):
|
||||
"""Test that non-normalized proportions are normalized."""
|
||||
# Proportions sum to 2.0, not 1.0
|
||||
proportions = [0.6, 1.0, 0.4]
|
||||
breakpoints = temporal_proportions_to_breakpoints(proportions)
|
||||
|
||||
# Should be normalized: [0.3, 0.5, 0.2] -> [0, 0.3, 0.8, 1.0]
|
||||
assert breakpoints[-1] == 1.0
|
||||
assert abs(breakpoints[1] - 0.3) < 1e-6
|
||||
|
||||
|
||||
class TestNormalizeStageTau:
|
||||
"""Tests for normalize_stage_tau.
|
||||
|
||||
Normalizes stage+tau values to [0, 1] using breakpoints.
|
||||
"""
|
||||
|
||||
def test_linear_fallback(self):
|
||||
"""Test linear normalization when only num_stages is provided."""
|
||||
# 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0]
|
||||
|
||||
# Stage 0 start
|
||||
assert normalize_stage_tau(0.0, num_stages=4) == 0.0
|
||||
|
||||
# Stage 0 end / Stage 1 start
|
||||
assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6
|
||||
|
||||
# Stage 1 middle
|
||||
assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6
|
||||
|
||||
# Stage 3 end
|
||||
assert normalize_stage_tau(4.0, num_stages=4) == 1.0
|
||||
|
||||
def test_with_custom_breakpoints(self):
|
||||
"""Test with custom breakpoints."""
|
||||
# Non-linear breakpoints
|
||||
breakpoints = [0.0, 0.1, 0.5, 1.0] # 3 stages
|
||||
|
||||
# Stage 0: maps [0, 1) to [0.0, 0.1)
|
||||
assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6
|
||||
|
||||
# Stage 1: maps [1, 2) to [0.1, 0.5)
|
||||
assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6
|
||||
|
||||
# Stage 2: maps [2, 3) to [0.5, 1.0)
|
||||
assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6
|
||||
|
||||
def test_with_temporal_proportions(self):
|
||||
"""Test with temporal proportions (auto-computed breakpoints)."""
|
||||
proportions = {"a": 0.2, "b": 0.3, "c": 0.5}
|
||||
subtask_names = ["a", "b", "c"]
|
||||
|
||||
# Stage 0 end should map to 0.2
|
||||
result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names)
|
||||
assert abs(result - 0.2) < 1e-6
|
||||
|
||||
# Stage 1 end should map to 0.5
|
||||
result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names)
|
||||
assert abs(result - 0.5) < 1e-6
|
||||
|
||||
def test_tensor_input(self):
|
||||
"""Test with tensor input."""
|
||||
x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
|
||||
breakpoints = [0.0, 0.3, 0.8, 1.0] # 3 stages
|
||||
|
||||
result = normalize_stage_tau(x, breakpoints=breakpoints)
|
||||
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.shape == x.shape
|
||||
assert abs(result[0].item() - 0.0) < 1e-6
|
||||
assert abs(result[2].item() - 0.3) < 1e-6 # End of stage 0
|
||||
assert abs(result[4].item() - 0.8) < 1e-6 # End of stage 1
|
||||
|
||||
def test_clamping(self):
|
||||
"""Test that output is clamped to [0, 1]."""
|
||||
# Below 0
|
||||
assert normalize_stage_tau(-0.5, num_stages=4) == 0.0
|
||||
|
||||
# Above num_stages
|
||||
assert normalize_stage_tau(5.0, num_stages=4) == 1.0
|
||||
|
||||
def test_batch_tensor(self):
|
||||
"""Test with batched tensor."""
|
||||
x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]]) # (2, 3)
|
||||
|
||||
result = normalize_stage_tau(x, num_stages=3)
|
||||
|
||||
assert result.shape == (2, 3)
|
||||
assert (result >= 0).all()
|
||||
assert (result <= 1).all()
|
||||
|
||||
def test_requires_one_of_inputs(self):
|
||||
"""Test that at least one input method is required."""
|
||||
with pytest.raises(ValueError):
|
||||
normalize_stage_tau(1.0)
|
||||
|
||||
|
||||
class TestRewindAugmentation:
|
||||
"""Tests for rewind augmentation logic with bidirectional observation sampling.
|
||||
|
||||
Rewind appends frames before the earliest observation frame, going backwards.
|
||||
With bidirectional sampling centered at frame_idx:
|
||||
- Earliest obs frame = frame_idx - half_steps * frame_gap
|
||||
- Rewind goes backwards from that point
|
||||
"""
|
||||
|
||||
def test_rewind_indices_go_backwards_from_earliest_obs(self):
|
||||
"""Rewind indices should go backwards from earliest observation frame."""
|
||||
frame_idx = 300 # Center of bidirectional window
|
||||
ep_start = 0
|
||||
n_obs_steps = 4 # half_steps = 2
|
||||
frame_gap = 30
|
||||
|
||||
# Earliest obs frame = 300 - 2*30 = 240
|
||||
# Rewind goes backwards: 210, 180
|
||||
rewind_step, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx,
|
||||
ep_start,
|
||||
n_obs_steps=n_obs_steps,
|
||||
max_rewind_steps=2,
|
||||
frame_gap=frame_gap,
|
||||
rewind_step=2,
|
||||
)
|
||||
|
||||
assert rewind_step == 2
|
||||
assert len(rewind_indices) == 2
|
||||
# First rewind frame is closest to obs window, second is further back
|
||||
assert rewind_indices[0] == 210 # 240 - 30
|
||||
assert rewind_indices[1] == 180 # 240 - 60
|
||||
assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending"
|
||||
|
||||
def test_rewind_goes_backward_through_history(self):
|
||||
"""Rewind frames should go backward before the observation window."""
|
||||
frame_idx = 450 # Center of bidirectional window
|
||||
ep_start = 0
|
||||
n_obs_steps = 8 # half_steps = 4
|
||||
frame_gap = 30
|
||||
|
||||
# Earliest obs frame = 450 - 4*30 = 330
|
||||
# Rewind from 330: [300, 270, 240]
|
||||
rewind_step, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx,
|
||||
ep_start,
|
||||
n_obs_steps=n_obs_steps,
|
||||
max_rewind_steps=4,
|
||||
frame_gap=frame_gap,
|
||||
rewind_step=3,
|
||||
)
|
||||
|
||||
assert rewind_step == 3
|
||||
expected = [300, 270, 240] # Going backwards from 330
|
||||
assert rewind_indices == expected
|
||||
|
||||
def test_no_rewind_when_obs_window_at_episode_start(self):
|
||||
"""No rewind when observation window reaches episode start."""
|
||||
frame_idx = 120 # Center of window
|
||||
ep_start = 0
|
||||
n_obs_steps = 8 # half_steps = 4
|
||||
frame_gap = 30
|
||||
|
||||
# Earliest obs frame = 120 - 4*30 = 0 (at episode start)
|
||||
rewind_step, rewind_indices = apply_rewind_augmentation(
|
||||
frame_idx, ep_start, n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap
|
||||
)
|
||||
|
||||
# No room for rewind
|
||||
assert rewind_step == 0
|
||||
assert rewind_indices == []
|
||||
|
||||
def test_rewind_targets_are_decreasing(self):
|
||||
"""Progress targets for rewind frames should be decreasing."""
|
||||
# Simulate progress values
|
||||
obs_progress = [0.1, 0.2, 0.3, 0.4, 0.5] # Forward progress
|
||||
|
||||
# Rewind reverses progress
|
||||
rewind_indices = [4, 3, 2] # Go backwards through indices
|
||||
rewind_progress = [obs_progress[i] for i in rewind_indices]
|
||||
|
||||
# Should be decreasing
|
||||
for i in range(len(rewind_progress) - 1):
|
||||
assert rewind_progress[i] > rewind_progress[i + 1]
|
||||