mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Add sarm (#2639)
* add initial modeling
* make rewind pretrained policy
* add annotation
* small fix
* add sarm
* subtasks
* fix spawn
* fix rewind discrepancies
* Add script to generate embedding for dataset (#2138)
* Add generate and validate script
* fix precommit
* Improve generate embeddings function by using dataset tools (#2206)
---------
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
* cleanup
* change order train log
* print batch size
* update sarm processor
* add reward output
* change expected features
* add image validation
* change validation
* get state input from dataset stats
* raise if no state key is found
* pass stats
* cleanup and refactor
* add episode inddex to complementary data
* add subtask init and detection
* revert lerobot_train changes
* pass dataset metadata to policy
* change loadig subtasks
* add small logging
* fix progress conversion and adding initial frame
* use large offset for initial frame (ugly)
* Remove rewind, use clip tokenizer
* add tests, implement formula 1,2 correctly and cleanup
* use task from dataset, cleanup visualizer
* simplify
* simplify and cleanup code and move compute_temporal_proportions to utils
* fix normalization in visualization
* Fix visualization and change prompt
* fix formatting
* add visualize subtask annotations
* use qwen thinking
* try different prompt
* format
* update prompt
* higher temp, long output
* different settings
* use instruct
* show full resp
* split message
* Temp: increase tolerance dataset
* Fix RA-BC (#2572)
* Add next observation loading for RA-BC progress deltas
* Compute weights based on temporal progress deltas instead of static rewards
* Add hard-masking for negative progress deltas in weight computation
* Feat/add dual head (#2582)
* Add dual dense sparse head and annotation
* Add docs
* add dual to procesor
* cleanup
* change sampling in visualize and cleanup
* remove validation
* remove compile
* Feat/test uniform (#2587)
* test uniform
* add different string for misaligned
* Fix rewind and add tests
* uncomment text implementation
* run precommit
* Add head mode for ra-bc
* fix visalization of single task
* add
* return per sample loss
* Fix RA_BC (#2602)
* update rabc implementation
* compute rabc beforehand
* fix import
* add only progress calulation
* use precomputed progress
* multi gpu processing
* import
* fix dataset meta data extraction
* add logging
* logging
* log
* progress per episode
* split differently
* move clip to gpu
* pre decode frames for an episode
* fix cuda initalization
* fix import
* multi processing
* rename
* fix import
* fix
* fix rabc
* use last known progress if oob
* use last known progress if oob
* add misalignment loss with random embeddings
* discard previous changes
* add selection of models to docs for ra_bc
* add transformers dep
* extend tolerance
* initial commit with new codebase
* add tests
* fix
* remove temporal sampler
* drop last frame for sampler
* use original ref
* some fixes
* fix visualization
* remove smoothing and fix order subtasks
* add stride rabc computation
* add push to hub
* add explanation
* add kappa expllaination
* better rabc logging
* feedback pr
* remove dataset tolerance
* revert dataset tool
* revert dataset changes
* add credit
* run precommit
* change path for generate ra_bc
* fix type
* include sarm in all in pyproject
* fix precommit
* lazy import matplotlib
* lazy import qwen
* remove rich console
* skip if transformers is not installed?
* run only when we have faker
* place transformer lazy loading
* Dont test if low transformer version
* fix
* increase transformer
* increase as 4.57.0 is yanked
* remove pi from all
* go back
---------
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
```diff
@@ -42,6 +42,10 @@
   - local: xvla
     title: X-VLA
   title: "Policies"
+- sections:
+  - local: sarm
+    title: SARM
+  title: "Reward Models"
 - sections:
   - local: async
     title: Use Async Inference
```
# SARM: Stage-Aware Reward Modeling

SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework for long-horizon robot manipulation tasks. This guide covers how to train SARM reward models and optionally use them with Reward-Aligned Behavior Cloning (RA-BC).

**Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358)

## Why Reward Models?

Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways; two promising applications are (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement.
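The weighted imitation learning idea can be sketched in a few lines. This is an illustrative example only, not LeRobot's implementation: `weighted_bc_loss` is a hypothetical helper, and a plain MSE stands in for whatever action loss the policy actually uses.

```python
import numpy as np

def weighted_bc_loss(pred_actions, target_actions, weights):
    # Per-sample squared error over action dimensions, scaled by a
    # progress-derived weight in [0, 1]; zero-weight samples drop out.
    per_sample = ((pred_actions - target_actions) ** 2).mean(axis=-1)
    return float((weights * per_sample).sum() / max(weights.sum(), 1e-8))
```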
## Overview

SARM has the following features:

1. **Stage-aware architecture**: Jointly predicts the high-level task stage and fine-grained progress within each stage
2. **Subtask annotations**: Uses natural language subtask annotations to derive consistent progress labels
3. **Temporal proportions**: Computes dataset-level priors (α̅\_k) for each subtask to normalize progress across variable-length demonstrations
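The temporal-proportion priors can be illustrated with a small sketch. The actual computation lives in the SARM utilities; the episode format assumed here (one dict per episode mapping subtask name to a `(start_frame, end_frame)` span) is a simplification for the example.

```python
import numpy as np

def compute_temporal_proportions(episodes):
    """Average fraction of each episode spent in every subtask (the alpha-bar_k priors)."""
    fractions = {}
    for ep in episodes:
        total = max(end for _, end in ep.values()) - min(start for start, _ in ep.values())
        for name, (start, end) in ep.items():
            fractions.setdefault(name, []).append((end - start) / total)
    # Average across episodes, then renormalize so the proportions sum to 1.
    means = {name: float(np.mean(v)) for name, v in fractions.items()}
    z = sum(means.values())
    return {name: m / z for name, m in means.items()}
```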
SARM trains on a compact **stage+tau** target for each frame:

- **stage**: integer stage index `k ∈ {0, ..., K-1}`
- **τ (tau)**: within-stage progress `τ ∈ [0, 1]`
- **target encoding**: `y = k + τ` (this is what the dataset processor produces)

At inference time (and in downstream RA-BC), SARM converts the raw `k + τ` value into a **normalized progress** in `[0, 1]` using dataset-level **temporal proportions** `α̅_k` (stored in `meta/temporal_proportions_*.json`).

This matches **Formula (2)** from the paper:

```
progress_t = P_{k-1} + α̅_k × τ_t
```

Where:

- `τ_t = (t - s_k) / (e_k - s_k)` is the within-subtask normalized time
- `P_{k-1}` is the cumulative prior (the sum of the previous subtasks' proportions)
- `α̅_k` is the temporal proportion for subtask k

This ensures identical task states map to consistent progress values, even across demonstrations of different lengths.
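A minimal sketch of this conversion, assuming the proportions are given as an ordered list summing to 1 (e.g. the values from `meta/temporal_proportions_*.json`); `normalize_progress` is a hypothetical helper, not the library function.

```python
def normalize_progress(stage_plus_tau, proportions):
    """Convert a raw `y = k + tau` prediction into normalized progress in [0, 1]."""
    k = min(int(stage_plus_tau), len(proportions) - 1)  # stage index
    tau = stage_plus_tau - k                            # within-stage progress
    cumulative_prior = sum(proportions[:k])             # P_{k-1}
    return cumulative_prior + proportions[k] * tau      # Formula (2)
```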
## Inputs and Targets (What the new code expects)

SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:

- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
- **Builds targets** as `sparse_targets` (and `dense_targets` in `dense_only`/`dual`) using the stage+tau encoding `y = k + τ`
- **Masks rewind frames** using a per-sample `lengths` tensor (rewind is a training-time augmentation)

At minimum, each training sample needs:

- `task` (string): task description
- `policy.image_key` images and `policy.state_key` states from the dataset
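The stage+tau target construction can be sketched as follows. This is a simplified stand-in for what the processor does when building `sparse_targets`/`dense_targets`; the `boundaries` format (ordered, contiguous `(start_frame, end_frame)` spans per subtask) is an assumption for the example.

```python
def stage_tau_targets(num_frames, boundaries):
    """Per-frame `y = k + tau` targets from ordered subtask (start, end) frame spans."""
    targets = []
    for t in range(num_frames):
        for k, (start, end) in enumerate(boundaries):
            if t < end or k == len(boundaries) - 1:
                # Within-stage progress, clamped to [0, 1].
                tau = min(max((t - start) / max(end - start, 1), 0.0), 1.0)
                targets.append(k + tau)
                break
    return targets
```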
---

## Set Up Your Environment

1. Install LeRobot by following our [Installation Guide](./installation).
2. Install SARM dependencies by running:

```bash
pip install -e ".[sarm]"
```

## Annotation Modes

You can choose from **3 annotation modes** that determine how progress labels are computed:

| Mode | Annotations Required | Heads | Use Case |
| -------------- | -------------------- | ---------------------------- | ------------------------------------------------------------ |
| `single_stage` | None | Sparse only | Simple tasks, quick experiments, no VLM needed |
| `dense_only` | Dense (VLM) | Dual (sparse auto-generated) | Detailed subtask tracking without defining high-level stages |
| `dual` | Sparse + Dense (VLM) | Dual | Full SARM paper setup with both granularities |

### Mode Details

<hfoptions id="mode_explanation">
<hfoption id="single_stage">

**No annotations required.** The entire episode is treated as a single stage called `"task"`, and progress is linear from 0 to 1 over the episode duration.

- **Sparse head**: 1 stage ("task"), linear progress
- **Dense head**: Not used
- **Best for**: Simple tasks, quick experiments, or when VLM annotation is not available

Workflow:

```
1. Train SARM → 2. Visualize predictions → 3. (Optional) Train policy with RA-BC
```

</hfoption>
<hfoption id="dense_only">

**Only dense (fine-grained) annotations from a VLM.** The sparse head automatically uses a single `"task"` stage covering the full episode, while the dense head learns detailed subtask progression.

- **Sparse head**: 1 stage ("task"), linear progress (auto-generated)
- **Dense head**: Multiple fine-grained stages from VLM annotations
- **Best for**: When you want detailed subtask tracking but don't need to define high-level stages

Workflow:

```
1. Annotate (dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
```

</hfoption>
<hfoption id="dual">

**Both sparse and dense annotations from VLM.** Full dual-head mode as described in the SARM paper, with both high-level (sparse) and fine-grained (dense) stage predictions.

- **Sparse head**: High-level stages from VLM annotations
- **Dense head**: Fine-grained stages from VLM annotations
- **Best for**: Complex multi-stage tasks where both granularities are useful

Workflow:

```
1. Annotate (sparse+dense) → 2. Verify → 3. Train SARM → 4. Visualize → 5. (Optional) Train policy with RA-BC
```

</hfoption>
</hfoptions>
---

## Step 1: Subtask Annotation

<hfoptions id="annotation_mode">
<hfoption id="single_stage">

**No annotation required!** Skip this step entirely. The model will use the episode's task description and compute linear progress automatically.

</hfoption>
<hfoption id="dense_only">

Generate **dense (fine-grained) annotations only** using a VLM. The sparse stage will be auto-generated.

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --dense-only \
    --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
    --video-key observation.images.base \
    --num-workers 4 \
    --push-to-hub
```

**What gets saved:**

- `meta/temporal_proportions_sparse.json` - Auto-generated sparse proportions (`{"task": 1.0}`)
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
- Per-episode columns in `episodes/*.parquet`:
  - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
  - (also time-based columns: `dense_subtask_start_times`, `dense_subtask_end_times`)

</hfoption>
<hfoption id="dual">

Generate **both sparse (high-level) and dense (fine-grained) annotations** using a VLM.

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --sparse-subtasks "Bring arms up from starting position,Fold the towel (3 folds in total)" \
    --dense-subtasks "Bring robot arms up from starting position,Grab near side and do 1st fold,Grab side and do 2nd fold,Grab side and do 3rd fold to finish folding" \
    --video-key observation.images.base \
    --num-workers 4 \
    --push-to-hub
```

**What gets saved:**

- `meta/temporal_proportions_sparse.json` - Sparse temporal proportions
- `meta/temporal_proportions_dense.json` - Dense temporal proportions
- Per-episode columns in `episodes/*.parquet`:
  - `sparse_subtask_names`, `sparse_subtask_start_frames`, `sparse_subtask_end_frames`
  - `dense_subtask_names`, `dense_subtask_start_frames`, `dense_subtask_end_frames`
  - (also time-based columns: `*_subtask_start_times`, `*_subtask_end_times`)

</hfoption>
</hfoptions>
### Annotation Arguments

| Argument | Description |
| ---------------------- | ------------------------------------------------------------------------------- |
| `--repo-id` | HuggingFace dataset repository ID |
| `--sparse-subtasks` | Comma-separated list of high-level subtask names |
| `--dense-subtasks` | Comma-separated list of fine-grained subtask names |
| `--dense-only` | Generate only dense annotations (auto-creates sparse "task" stage) |
| `--video-key` | Camera/video key to use (e.g., `observation.images.top`) |
| `--num-workers` | Number of parallel GPU workers (default: 1) |
| `--episodes` | Specific episode indices to annotate (default: all) |
| `--skip-existing` | Skip episodes that already have annotations |
| `--model` | VLM model (default: `Qwen/Qwen3-VL-30B-A3B-Instruct`) |
| `--num-visualizations` | Number of episodes to visualize after annotation (default: 5, set to 0 to skip) |

> **Note**: After annotation completes, 5 episodes are automatically visualized by default. Use `--num-visualizations 0` to skip this step.

---
## Step 2: Verify Annotations

<hfoptions id="verify_mode">
<hfoption id="single_stage">

**No verification needed!** Skip this step.

</hfoption>
<hfoption id="dense_only">

Visualize annotations using the `--visualize-only` flag:

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --visualize-only \
    --visualize-type dense \
    --num-visualizations 5 \
    --video-key observation.images.base \
    --output-dir ./subtask_viz
```

</hfoption>
<hfoption id="dual">

Visualize annotations using the `--visualize-only` flag:

```bash
python src/lerobot/data_processing/sarm_annotations/subtask_annotation.py \
    --repo-id your-username/your-dataset \
    --visualize-only \
    --visualize-type both \
    --num-visualizations 5 \
    --video-key observation.images.base \
    --output-dir ./subtask_viz
```

</hfoption>
</hfoptions>

This generates visualizations showing video frames with subtask boundaries overlaid and a timeline of subtasks.
### Visualization Arguments

| Argument | Description |
| ---------------------- | -------------------------------------------------------------- |
| `--visualize-only` | Only visualize existing annotations (no generation) |
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
| `--visualize-type` | Type of annotations to visualize: `sparse`, `dense`, or `both` |

**Tip**: If annotations are inaccurate, adjust your subtask descriptions to be more specific and re-run.

---
## Step 3: Train SARM

<hfoptions id="train_mode">
<hfoption id="single_stage">

Train with **no annotations** - uses linear progress from 0 to 1:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=sarm \
    --policy.annotation_mode=single_stage \
    --policy.image_key=observation.images.base \
    --output_dir=outputs/train/sarm_single \
    --batch_size=32 \
    --steps=5000 \
    --wandb.enable=true \
    --wandb.project=sarm \
    --policy.repo_id=your-username/your-model-name
```

</hfoption>
<hfoption id="dense_only">

Train with **dense annotations only** (sparse auto-generated):

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=sarm \
    --policy.annotation_mode=dense_only \
    --policy.image_key=observation.images.base \
    --output_dir=outputs/train/sarm_dense \
    --batch_size=32 \
    --steps=5000 \
    --wandb.enable=true \
    --wandb.project=sarm \
    --policy.repo_id=your-username/your-model-name
```

</hfoption>
<hfoption id="dual">

Train with **both sparse and dense annotations**:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=sarm \
    --policy.annotation_mode=dual \
    --policy.image_key=observation.images.base \
    --output_dir=outputs/train/sarm_dual \
    --batch_size=32 \
    --steps=5000 \
    --wandb.enable=true \
    --wandb.project=sarm \
    --policy.repo_id=your-username/your-model-name
```

</hfoption>
</hfoptions>
### Multi-GPU Training

Prefix the training command with `accelerate launch --multi_gpu --num_processes=4` to train on multiple GPUs.

### Training Arguments

| Argument | Description | Default |
| -------------------------- | ----------------------------------------------------------------- | ------------------------ |
| `--policy.annotation_mode` | `single_stage`, `dense_only`, or `dual` | `single_stage` |
| `--policy.image_key` | Camera key for images | `observation.images.top` |
| `--policy.state_key` | Key for joint states | `observation.state` |
| `--policy.n_obs_steps` | Observation history steps (total obs frames = `n_obs_steps + 1`) | `8` |
| `--policy.frame_gap` | Gap (in frames) between sampled observations (at 30 fps: 30 ≈ 1s) | `30` |
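As an illustration of these two arguments, the frames fed to the model at time `t` could be sampled like this sketch. Clamping at the episode start is an assumption here; how the actual implementation pads before the first frame is an implementation detail.

```python
def sampled_frame_indices(t, n_obs_steps=8, frame_gap=30):
    """Frame indices observed at time t: the current frame plus `n_obs_steps`
    history frames spaced `frame_gap` apart, clamped at the episode start."""
    return [max(t - i * frame_gap, 0) for i in range(n_obs_steps, -1, -1)]
```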
---

## Step 4: Visualize Predictions

Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predictions (and, if available, annotation-derived targets) without writing a parquet file.

<hfoptions id="viz_mode">
<hfoption id="single_stage">

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
    --num-visualizations 5 \
    --head-mode sparse \
    --output-dir ./sarm_viz
```

</hfoption>
<hfoption id="dense_only">

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
    --num-visualizations 5 \
    --head-mode dense \
    --output-dir ./sarm_viz
```

</hfoption>
<hfoption id="dual">

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --visualize-only \
    --num-visualizations 5 \
    --head-mode both \
    --output-dir ./sarm_viz
```

</hfoption>
</hfoptions>

The visualization shows:

- **Progress plot**: Predicted progress (and optional annotation-derived "GT" when available and `--stride 1`)
- **Stage probabilities**: Stacked area plot of predicted stage probabilities
- **Sample frames**: Key frames from the episode with progress/stage labels

### Visualization Arguments

| Argument | Description |
| ---------------------- | --------------------------------------------------------- |
| `--visualize-only` | Only visualize predictions (no RABC computation) |
| `--num-visualizations` | Number of episodes to visualize (default: 5) |
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` |
| `--stride` | Compute every N frames, interpolate the rest (default: 1) |
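With `--stride N`, progress is only computed on every Nth frame and the gaps are filled in between. A rough sketch of that fill step, assuming simple linear interpolation (`interpolate_progress` is a hypothetical helper):

```python
import numpy as np

def interpolate_progress(frame_count, computed):
    """Fill per-frame progress when it was only computed on a subset of frames.

    `computed` maps frame index -> progress for the strided frames.
    """
    xs = sorted(computed)
    ys = [computed[x] for x in xs]
    return np.interp(np.arange(frame_count), xs, ys)
```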
---

## Step 5 (Optional): Train Policy with RA-BC

Reward-Aligned Behavior Cloning (RA-BC) uses the trained SARM model to weight training samples based on predicted progress improvement. This requires two steps:

1. **Precompute progress values** for all frames using the trained SARM model
2. **Train policy** with RA-BC weighting using the precomputed values

### How RA-BC Works

For each training sample, RA-BC computes the progress delta:

```
r_i = φ(o_{t+Δ}) - φ(o_t)
```

Where `φ` is the SARM progress prediction and `Δ` is the policy's `chunk_size`. Samples with positive progress (good demonstrations) get higher weights, while samples with negative or zero progress get down-weighted.

The weighting follows **Equations 8-9** from the paper:

- **Soft weight**: `w̃_i = clip((r_i − (μ − 2σ)) / (4σ + ε), 0, 1)`
- **Final weight**: `w_i = 𝟙{r_i > κ} + 𝟙{0 ≤ r_i ≤ κ} × w̃_i`
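Assuming μ and σ are the mean and standard deviation of the deltas, the two equations can be sketched as follows (`rabc_weights` is a hypothetical helper, not the training-code function):

```python
import numpy as np

def rabc_weights(deltas, kappa=0.01, eps=1e-8):
    """RA-BC weights from progress deltas r_i: 1.0 above kappa,
    the soft weight on [0, kappa], and 0.0 for negative deltas."""
    deltas = np.asarray(deltas, dtype=np.float64)
    mu, sigma = deltas.mean(), deltas.std()
    # Soft weight (Eq. 8): normalize r_i against the delta statistics.
    soft = np.clip((deltas - (mu - 2 * sigma)) / (4 * sigma + eps), 0.0, 1.0)
    # Final weight (Eq. 9): hard threshold, soft band, hard mask.
    return np.where(deltas > kappa, 1.0, np.where(deltas >= 0.0, soft, 0.0))
```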
### Step 5a: Compute SARM Progress Values

First, run the SARM model on all frames in your dataset to compute progress values:

```bash
python src/lerobot/policies/sarm/compute_rabc_weights.py \
    --dataset-repo-id your-username/your-dataset \
    --reward-model-path your-username/sarm-model \
    --head-mode sparse \
    --num-visualizations 5 \
    --push-to-hub
```

This script:

- Processes all frames and computes progress values
- Saves progress values to a parquet file next to the dataset on disk (defaults to `<dataset_root>/sarm_progress.parquet`)
- Generates visualizations of the first N episodes (default: 5)

**Arguments:**

| Argument | Description | Default |
| ---------------------- | ------------------------------------------------- | ---------- |
| `--reward-model-path` | Path to trained SARM model | (required) |
| `--head-mode` | SARM head to use: `sparse`, `dense`, or `both` | `sparse` |
| `--device` | Device for inference | `cuda` |
| `--visualize-only` | Only visualize predictions (no RA-BC computation) | `false` |
| `--num-visualizations` | Number of episodes to visualize (0 to skip) | `5` |

**Output format** (`sarm_progress.parquet`):

| Column | Description |
| ----------------- | ---------------------------------------------- |
| `index` | Global frame index in dataset |
| `episode_index` | Episode number |
| `frame_index` | Local frame index within episode |
| `progress_sparse` | Sparse head progress value [0, 1] |
| `progress_dense` | Dense head progress value [0, 1] (if computed) |
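Given one episode's `progress_sparse` column, the progress deltas could be derived like this sketch (a hypothetical helper; it reuses the last known progress when `t + Δ` runs past the episode end, though the exact out-of-bounds handling in the training code may differ):

```python
import numpy as np

def progress_deltas(progress, chunk_size):
    """r_i = phi(o_{t+delta}) - phi(o_t) for one episode's per-frame progress."""
    progress = np.asarray(progress, dtype=np.float64)
    n = len(progress)
    shifted = np.empty_like(progress)
    for t in range(n):
        # Clamp the lookahead to the last frame of the episode.
        shifted[t] = progress[min(t + chunk_size, n - 1)]
    return shifted - progress
```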
### Step 5b: Train Policy with RA-BC

Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). RA-BC currently supports the PI0, PI0.5, and SmolVLA policies:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=pi0 \
    --use_rabc=true \
    --rabc_head_mode=sparse \
    --rabc_kappa=0.01 \
    --output_dir=outputs/train/policy_rabc \
    --batch_size=32 \
    --steps=40000
```

The training script automatically:

- Loads the precomputed progress values from the parquet file
- Uses the policy's `chunk_size` to compute progress deltas (Δ)
- Computes sample weights based on progress improvement
- Applies weighted loss during training

**RA-BC Arguments:**

| Argument | Description | Default |
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
### Tuning RA-BC Kappa

The `kappa` parameter is the threshold that determines which samples get full weight (w=1). Understanding how to tune it is critical for RA-BC to work effectively.

**How the weighting works:**

| Condition | Weight |
| ------------------- | ----------------------- |
| `delta > kappa` | 1.0 (hard threshold) |
| `0 ≤ delta ≤ kappa` | Soft weight from Eq. 8 |
| `delta < 0` | 0.0 (negative progress) |

**Diagnosing kappa issues:**

Monitor these WandB metrics during training:

| Metric | Healthy Range | Problem Indicator |
| ------------------ | ------------- | ------------------------- |
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `rabc_delta_mean` | > 0 | Should be positive |
| `rabc_delta_std` | > 0 | Variance in data quality |

**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.

**Setting kappa based on your data:**

The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:

```
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
# Most deltas fall in range [0.01, 0.05]

# Option 1: Set kappa = delta_mean (medium selectivity)
--rabc_kappa=0.03

# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
--rabc_kappa=0.05

# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
--rabc_kappa=0.07
```

**When RA-BC may not help:**

If your dataset is already high quality (consistent progress across all demonstrations), RA-BC won't provide much benefit since there's nothing to filter.
### Multi-GPU Training with RA-BC

```bash
accelerate launch \
    --multi_gpu \
    --num_processes=4 \
    src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your-username/your-dataset \
    --policy.type=pi0 \
    --use_rabc=true \
    --rabc_kappa=0.01 \
    --output_dir=outputs/train/policy_rabc \
    --batch_size=32 \
    --steps=40000
```

---

## Tips & Best Practices

### Choosing a Mode

- **Start with `single_stage`** for quick experiments - no annotation overhead
- Use **`dense_only`** when you want detailed progress tracking but tasks don't have clear high-level stages
- Use **`dual`** for complex tasks where both coarse and fine-grained progress is meaningful

### Annotation Quality

1. **Be specific with subtask names**: Instead of "fold", use "grab near side and fold toward center"
2. **Verify with visualization**: Always check a few episodes before training
3. **Consistent naming**: Use the same subtask names across all episodes

### RA-BC

1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))

---

## Citation

```bibtex
@article{chen2025sarm,
  title={SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation},
  author={Chen, Qianzhong and Yu, Justin and Schwager, Mac and Abbeel, Pieter and Shentu, Yide and Wu, Philipp},
  journal={arXiv preprint arXiv:2509.25358},
  year={2025}
}
```
```diff
@@ -96,7 +96,7 @@ dependencies = [
 # Common
 pygame-dep = ["pygame>=2.5.1,<2.7.0"]
 placo-dep = ["placo>=0.9.6,<0.10.0"]
-transformers-dep = ["transformers>=4.53.0,<5.0.0"]
+transformers-dep = ["transformers>=4.57.1,<5.0.0"]
 grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)

 # Motors
```
@@ -133,6 +133,7 @@ groot = [
|
|||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
|
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
@@ -173,6 +174,7 @@ all = [
|
|||||||
"lerobot[phone]",
|
"lerobot[phone]",
|
||||||
"lerobot[libero]",
|
"lerobot[libero]",
|
||||||
"lerobot[metaworld]",
|
"lerobot[metaworld]",
|
||||||
|
"lerobot[sarm]"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
@@ -65,9 +65,17 @@ class TrainPipelineConfig(HubMixin):
     scheduler: LRSchedulerConfig | None = None
     eval: EvalConfig = field(default_factory=EvalConfig)
     wandb: WandBConfig = field(default_factory=WandBConfig)
-    checkpoint_path: Path | None = field(init=False, default=None)
+
+    # RA-BC (Reward-Aligned Behavior Cloning) parameters
+    use_rabc: bool = False  # Enable reward-weighted training
+    rabc_progress_path: str | None = None  # Path to precomputed SARM progress parquet file
+    rabc_kappa: float = 0.01  # Hard threshold for high-quality samples
+    rabc_epsilon: float = 1e-6  # Small constant for numerical stability
+    rabc_head_mode: str | None = "sparse"  # For dual-head models: "sparse" or "dense"
+
     # Rename map for the observation to override the image and state keys
     rename_map: dict[str, str] = field(default_factory=dict)
+    checkpoint_path: Path | None = field(init=False, default=None)

     def validate(self) -> None:
         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
@@ -131,6 +139,14 @@ class TrainPipelineConfig(HubMixin):
                 "'policy.repo_id' argument missing. Please specify it to push the model to the hub."
             )

+        if self.use_rabc and not self.rabc_progress_path:
+            # Auto-detect from dataset path
+            repo_id = self.dataset.repo_id
+            if self.dataset.root:
+                self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
+            else:
+                self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
+
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
@@ -0,0 +1,13 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
File diff suppressed because it is too large
@@ -29,6 +29,7 @@ __all__ = [
     "PI0Config",
     "PI05Config",
     "SmolVLAConfig",
+    "SARMConfig",
     "TDMPCConfig",
     "VQBeTConfig",
     "GrootConfig",
@@ -50,6 +50,7 @@ class ACTPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: ACTConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -56,6 +56,7 @@ class DiffusionPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: DiffusionConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -37,6 +37,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.policies.sarm.configuration_sarm import SARMConfig
 from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.utils import validate_visual_features_consistency
@@ -105,6 +106,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

         return SmolVLAPolicy
+    elif name == "sarm":
+        from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
+
+        return SARMRewardModel
     elif name == "groot":
         from lerobot.policies.groot.modeling_groot import GrootPolicy

@@ -337,6 +342,14 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )
+
+    elif isinstance(policy_cfg, SARMConfig):
+        from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
+
+        processors = make_sarm_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            dataset_meta=kwargs.get("dataset_meta"),
+        )
     elif isinstance(policy_cfg, GrootConfig):
         from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors

@@ -435,6 +448,13 @@ def make_policy(
     cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
     kwargs["config"] = cfg
+
+    # Pass dataset_stats to the policy if available (needed for some policies like SARM)
+    if ds_meta is not None and hasattr(ds_meta, "stats"):
+        kwargs["dataset_stats"] = ds_meta.stats
+
+    if ds_meta is not None:
+        kwargs["dataset_meta"] = ds_meta
+
     if cfg.pretrained_path:
         # Load a pretrained policy and override the config if needed (for example, if there are inference-time
         # hyperparameters that we want to vary).
@@ -49,7 +49,7 @@ class GrootPolicy(PreTrainedPolicy):
     name = "groot"
     config_class = GrootConfig

-    def __init__(self, config: GrootConfig):
+    def __init__(self, config: GrootConfig, **kwargs):
         """Initialize Groot policy wrapper."""
         super().__init__(config)
         config.validate_features()
@@ -907,6 +907,7 @@ class PI0Policy(PreTrainedPolicy):
     def __init__(
         self,
         config: PI0Config,
+        **kwargs,
     ):
         """
         Args:
@@ -1235,9 +1236,15 @@ class PI0Policy(PreTrainedPolicy):

         return actions

-    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
-        """Run the batch through the model and compute the loss for training."""
+    def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
+        """Run the batch through the model and compute the loss for training.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
         # Prepare inputs
         images, img_masks = self._preprocess_images(batch)
         lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
@@ -1251,11 +1258,17 @@ class PI0Policy(PreTrainedPolicy):
         original_action_dim = self.config.output_features[ACTION].shape[0]
         losses = losses[:, :, :original_action_dim]

-        loss = losses.mean()
-
         loss_dict = {
-            "loss": loss.item(),
             "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
         }

-        return loss, loss_dict
+        if reduction == "none":
+            # Return per-sample losses (B,) by averaging over time and action dims
+            per_sample_loss = losses.mean(dim=(1, 2))
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = losses.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict
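The `reduction="none"` path exists so a training loop can reweight per-sample losses before reducing. A minimal sketch of that combination, using a plain NumPy stand-in for the tensors (whether the real training loop normalizes by batch size or by the weight sum is not shown in this diff, so this is illustrative only):

```python
import numpy as np


def rabc_weighted_loss(per_sample_loss, weights):
    """Combine reduction="none" per-sample losses with precomputed RA-BC
    weights via a plain weighted mean over the batch (illustrative only)."""
    per_sample_loss = np.asarray(per_sample_loss, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    return float((weights * per_sample_loss).mean())
```

A sample with weight 0 (hard-masked negative progress) then contributes nothing to the gradient, while full-weight samples behave exactly as under the default `"mean"` reduction.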
@@ -880,6 +880,7 @@ class PI05Policy(PreTrainedPolicy):
     def __init__(
         self,
         config: PI05Config,
+        **kwargs,
     ):
         """
         Args:
@@ -1209,9 +1210,15 @@ class PI05Policy(PreTrainedPolicy):

         return actions

-    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
-        """Run the batch through the model and compute the loss for training."""
+    def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
+        """Run the batch through the model and compute the loss for training.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
         # Prepare inputs
         images, img_masks = self._preprocess_images(batch)
         tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
@@ -1225,11 +1232,17 @@ class PI05Policy(PreTrainedPolicy):
         original_action_dim = self.config.output_features[ACTION].shape[0]
         losses = losses[:, :, :original_action_dim]

-        loss = losses.mean()
-
         loss_dict = {
-            "loss": loss.item(),
             "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
         }

-        return loss, loss_dict
+        if reduction == "none":
+            # Return per-sample losses (B,) by averaging over time and action dims
+            per_sample_loss = losses.mean(dim=(1, 2))
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = losses.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict
@@ -0,0 +1,870 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Compute SARM progress values for RA-BC (Reward-Aware Behavior Cloning) weighting.

This script processes all frames in a dataset with SARM to compute progress values [0, 1].
The results are saved as a parquet file that can be loaded during training for RA-BC weighting.

Uses multi-output extraction: each SARM query returns progress for 9 frames, so we only
need ~num_frames/30 queries instead of one per frame (~30x speedup).

Usage:
    # Full RA-BC computation with visualizations
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4

    # Faster computation with stride (compute every 5 frames, interpolate the rest)
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4 \\
        --stride 5

    # Visualize predictions only (no RA-BC computation)
    python src/lerobot/policies/sarm/compute_rabc_weights.py \\
        --dataset-repo-id lerobot/aloha_sim_insertion_human \\
        --reward-model-path pepijn223/sarm_single_uni4 \\
        --visualize-only \\
        --num-visualizations 5

The output is saved to the dataset's local cache directory as 'sarm_progress.parquet'.
"""

import argparse
import logging
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from tqdm import tqdm

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
from lerobot.policies.sarm.sarm_utils import normalize_stage_tau


def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
    """Read reward_model_path from parquet metadata if available."""
    if not parquet_path.exists():
        return None
    try:
        metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
        if metadata and b"reward_model_path" in metadata:
            return metadata[b"reward_model_path"].decode()
    except Exception:  # nosec B110
        return None
    return None

def load_sarm_resources(
    dataset_repo_id: str,
    reward_model_path: str,
    device: str = "cuda",
) -> tuple[LeRobotDataset, SARMRewardModel, any]:
    """
    Load SARM model, dataset, and preprocessor.

    Returns:
        Tuple of (dataset, reward_model, preprocessor)
    """
    logging.info(f"Loading model: {reward_model_path}")
    reward_model = SARMRewardModel.from_pretrained(reward_model_path)
    reward_model.config.device = device
    reward_model.to(device).eval()

    image_key = reward_model.config.image_key
    state_key = reward_model.config.state_key
    delta_indices = reward_model.config.observation_delta_indices

    logging.info(f"Loading dataset: {dataset_repo_id}")
    temp_dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
    fps = temp_dataset.fps

    delta_timestamps = {
        image_key: [idx / fps for idx in delta_indices],
        state_key: [idx / fps for idx in delta_indices],
    }
    dataset = LeRobotDataset(dataset_repo_id, delta_timestamps=delta_timestamps)
    logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")

    preprocess, _ = make_sarm_pre_post_processors(
        config=reward_model.config,
        dataset_stats=dataset.meta.stats,
        dataset_meta=dataset.meta,
    )

    return dataset, reward_model, preprocess


def to_numpy_image(img) -> np.ndarray:
    """Convert image tensor to numpy uint8 (H, W, C)."""
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy()
    if img.ndim == 4:
        # Take center frame for bidirectional sampling
        img = img[img.shape[0] // 2]
    if img.shape[0] in [1, 3]:
        img = np.transpose(img, (1, 2, 0))
    if img.dtype != np.uint8:
        # Handle normalized images (may have negative values or values > 1)
        img = img.astype(np.float32)
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # Normalize to [0, 1]
        img = (img * 255).astype(np.uint8)
    return img

def visualize_episode(
    frames, progress_preds, stage_preds, title, output_path, stage_labels, gt_progress=None, gt_stages=None
):
    """Create visualization with progress plot, stage probabilities, and sample frames.

    Same as sarm_inference_visualization.py
    """
    num_stages = stage_preds.shape[1]
    colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
    frame_indices = np.arange(len(progress_preds))

    fig = plt.figure(figsize=(14, 12))
    gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
    ax_progress, ax_stages, ax_frames = fig.add_subplot(gs[0]), fig.add_subplot(gs[1]), fig.add_subplot(gs[2])

    # Progress plot
    ax_progress.plot(frame_indices, progress_preds, linewidth=2, color="#2E86AB", label="Predicted")
    ax_progress.fill_between(frame_indices, 0, progress_preds, alpha=0.3, color="#2E86AB")
    if gt_progress is not None:
        ax_progress.plot(
            frame_indices, gt_progress, linewidth=2, color="#28A745", linestyle="--", label="Ground Truth"
        )
    ax_progress.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5)
    ax_progress.set_ylabel("Progress")
    ax_progress.set_title(f'Task: "{title}"', fontweight="bold")
    ax_progress.set_ylim(-0.05, 1.1)
    ax_progress.legend(loc="upper left")
    ax_progress.grid(True, alpha=0.3)

    # Stage predictions
    ax_stages.stackplot(
        frame_indices,
        *[stage_preds[:, i] for i in range(num_stages)],
        colors=colors,
        alpha=0.8,
        labels=stage_labels,
    )
    if gt_stages is not None:
        for change_idx in np.where(np.diff(gt_stages) != 0)[0] + 1:
            ax_stages.axvline(x=change_idx, color="black", linestyle="-", alpha=0.7, linewidth=1.5)
    ax_stages.set_xlabel("Frame")
    ax_stages.set_ylabel("Stage Probability")
    ax_stages.set_ylim(0, 1)
    ax_stages.legend(loc="upper left", ncol=min(num_stages, 5), fontsize=8)
    ax_stages.grid(True, alpha=0.3)

    # Sample frames
    ax_frames.axis("off")
    num_sample = 8
    sample_indices = np.linspace(0, len(frames) - 1, num_sample, dtype=int)
    h, w = frames[0].shape[:2]
    combined = np.zeros((h, w * num_sample, 3), dtype=np.uint8)
    for i, idx in enumerate(sample_indices):
        frame = frames[idx]
        if frame.shape[-1] == 1:
            frame = np.repeat(frame, 3, axis=-1)
        combined[:, i * w : (i + 1) * w] = frame
        stage_name = stage_labels[np.argmax(stage_preds[idx])][:12]
        ax_frames.text(
            i * w + w / 2,
            -10,
            f"Frame {idx}\n{progress_preds[idx]:.2f}\n{stage_name}",
            ha="center",
            va="top",
            fontsize=7,
        )
    ax_frames.imshow(combined)
    ax_frames.set_title("Sample Frames", pad=20)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {output_path}")

def visualize_sarm_predictions(
    dataset: LeRobotDataset,
    reward_model: SARMRewardModel,
    preprocess,
    episode_indices: list[int],
    head_mode: str,
    output_dir: Path,
    num_display_frames: int = 5,
    stride: int = 1,
):
    """
    Visualize SARM predictions for multiple episodes.

    Computes predictions for every frame by default. With stride > 1, computes predictions
    every N frames and interpolates (progress + stage probabilities) for visualization.

    Args:
        dataset: LeRobotDataset with delta_timestamps configured
        reward_model: Loaded SARM model
        preprocess: Preprocessor from make_sarm_pre_post_processors
        episode_indices: List of episode indices to visualize
        head_mode: "sparse", "dense", or "both"
        output_dir: Directory to save visualizations
        num_display_frames: Number of frames to display in thumbnail strip (default: 5)
        stride: Compute predictions every N frames, interpolate the rest (default: 1)
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    image_key = reward_model.config.image_key
    state_key = reward_model.config.state_key
    dual_mode = reward_model.config.uses_dual_heads
    device = reward_model.device

    # Center frame index for bidirectional sampling
    target_idx = reward_model.config.n_obs_steps // 2

    # Determine which heads to visualize
    schemes_to_viz = []
    if head_mode in ("sparse", "both") or not dual_mode:
        schemes_to_viz.append("sparse")
    if head_mode in ("dense", "both") and dual_mode:
        schemes_to_viz.append("dense")

    # Set preprocessor to eval mode to disable augmentations
    if hasattr(preprocess, "eval"):
        preprocess.eval()
    for step in preprocess.steps:
        if hasattr(step, "eval"):
            step.eval()

    for episode_idx in episode_indices:
        ep = dataset.meta.episodes[episode_idx]
        ep_start = ep["dataset_from_index"]
        ep_end = ep["dataset_to_index"]
        task = dataset[ep_start].get("task", "perform the task")
        num_frames = ep_end - ep_start

        # Select frames for display thumbnails (evenly sampled from begin to end)
        display_indices = set(
            [
                ep_start + int(i * (num_frames - 1) / (num_display_frames - 1))
                for i in range(num_display_frames)
            ]
            if num_frames >= num_display_frames
            else list(range(ep_start, ep_end))
        )
        viz_frames = {}

        # Load display frames up-front (stride mode might skip them otherwise).
        for frame_idx in display_indices:
            sample = dataset[frame_idx]
            viz_frames[frame_idx] = to_numpy_image(sample[image_key])

        # Initialize storage for each scheme
        scheme_data = {}
        for scheme in schemes_to_viz:
            num_stages = getattr(reward_model.config, f"num_{scheme}_stages")
            scheme_data[scheme] = {
                "viz_progress": np.full(num_frames, np.nan),
                "viz_stages": np.full((num_frames, num_stages), np.nan),
                "viz_gt_progress": np.full(num_frames, np.nan),
                "viz_gt_stages": np.full(num_frames, np.nan),
                "target_key": f"{scheme}_targets",
                "num_stages": num_stages,
                "temporal_props": getattr(reward_model.config, f"{scheme}_temporal_proportions"),
                "subtask_names": getattr(reward_model.config, f"{scheme}_subtask_names"),
            }

        if stride > 1:
            logging.info(f"Visualization stride={stride}: inferring every {stride} frames and interpolating")

        # Process frames one at a time to avoid memory buildup
        frame_indices = list(range(ep_start, ep_end, stride))
        if (ep_end - 1) not in frame_indices:
            frame_indices.append(ep_end - 1)
        frame_indices = sorted(set(frame_indices))

        for frame_idx in tqdm(frame_indices, desc=f"Episode {episode_idx}", leave=False):
            local_idx = frame_idx - ep_start
            sample = dataset[frame_idx]

            batch = {
                image_key: sample[image_key],
                "task": task,
                "index": frame_idx,
                "episode_index": episode_idx,
            }
            if state_key in sample:
                batch[state_key] = sample[state_key]

            with torch.no_grad():
                processed = preprocess(batch)
                video_features = processed["video_features"].to(device)
                text_features = processed["text_features"].to(device)
                state_features = processed.get("state_features")
                if state_features is not None:
                    state_features = state_features.to(device)
                lengths = processed.get("lengths")

                for scheme in schemes_to_viz:
                    sd = scheme_data[scheme]

                    # Ground truth
                    # In stride visualization mode, ground-truth plots can be misleading
                    # (only sparse points are available), so we skip GT.
                    if stride == 1 and sd["target_key"] in processed:
                        gt_target = processed[sd["target_key"]][0, target_idx].cpu().item()
                        sd["viz_gt_stages"][local_idx] = int(gt_target)
                        sd["viz_gt_progress"][local_idx] = normalize_stage_tau(
                            gt_target,
                            num_stages=sd["num_stages"],
                            temporal_proportions=sd["temporal_props"],
                            subtask_names=sd["subtask_names"],
                        )

                    # Predictions
                    reward, stage_probs = reward_model.calculate_rewards(
                        text_embeddings=text_features,
                        video_embeddings=video_features,
                        state_features=state_features,
                        lengths=lengths,
                        return_all_frames=True,
                        return_stages=True,
                        head_mode=scheme,
                    )

                    # Handle both tensor and numpy outputs
                    if isinstance(reward, torch.Tensor):
                        reward = reward.cpu().numpy()
                        stage_probs = stage_probs.cpu().numpy()

                    if reward.ndim == 2:
                        sd["viz_progress"][local_idx] = reward[0, target_idx]
                        sd["viz_stages"][local_idx] = stage_probs[0, target_idx, :]
                    else:
                        sd["viz_progress"][local_idx] = reward[target_idx]
                        sd["viz_stages"][local_idx] = stage_probs[target_idx, :]

            # Clear GPU memory after each frame
            del processed, video_features, text_features
            if state_features is not None:
                del state_features

            torch.cuda.empty_cache()

        # Interpolate predictions back to per-frame arrays for smooth visualization.
        if stride > 1:
            all_local = np.arange(num_frames)
            for scheme in schemes_to_viz:
                sd = scheme_data[scheme]

                valid = np.isfinite(sd["viz_progress"])
                valid_idx = np.where(valid)[0]
                if valid_idx.size >= 1:
                    sd["viz_progress"] = interpolate_progress(
                        valid_idx, sd["viz_progress"][valid_idx], all_local
                    )

                    stage_interp = np.zeros_like(sd["viz_stages"], dtype=np.float32)
                    for s in range(sd["num_stages"]):
                        stage_interp[:, s] = interpolate_progress(
                            valid_idx, sd["viz_stages"][valid_idx, s], all_local
                        )

                    stage_interp = np.clip(stage_interp, 0.0, 1.0)
                    row_sums = stage_interp.sum(axis=1, keepdims=True)
                    nz = row_sums.squeeze(-1) > 0
                    stage_interp[nz] = stage_interp[nz] / row_sums[nz]
                    sd["viz_stages"] = stage_interp
                else:
                    # No valid points: keep NaNs/zeros; visualization will be empty.
                    sd["viz_stages"] = np.nan_to_num(sd["viz_stages"], nan=0.0)

        # Generate visualization for each head
        ordered_viz_frames = [viz_frames[idx] for idx in sorted(display_indices)]
        for scheme in schemes_to_viz:
            sd = scheme_data[scheme]
            stage_labels = sd["subtask_names"] or [f"Stage {i + 1}" for i in range(sd["num_stages"])]
            viz_path = output_dir / f"sarm_prediction_ep{episode_idx}_{scheme}.png"

            visualize_episode(
                frames=np.array(ordered_viz_frames),
                progress_preds=sd["viz_progress"],
                stage_preds=sd["viz_stages"],
                title=f"{task} (Episode {episode_idx})",
                output_path=viz_path,
|
||||||
|
stage_labels=stage_labels,
|
||||||
|
gt_progress=sd["viz_gt_progress"] if not np.all(np.isnan(sd["viz_gt_progress"])) else None,
|
||||||
|
gt_stages=sd["viz_gt_stages"] if not np.all(np.isnan(sd["viz_gt_stages"])) else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear memory between episodes
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
logging.info(f"Visualizations saved to: {output_dir.absolute()}")
|
||||||
|
|
||||||
|
|
||||||
|
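The clip-and-renormalize step used when interpolating stage probabilities keeps each frame's row a valid distribution. A minimal standalone NumPy sketch of that recipe, using made-up probabilities (interpolation can push rows slightly outside [0, 1] or off unit sum):

```python
import numpy as np

# Hypothetical interpolated stage probabilities: 3 frames x 2 stages.
stage_interp = np.array([[0.7, 0.5], [1.2, -0.1], [0.0, 0.0]], dtype=np.float32)

# Same recipe as in the visualization code: clip to [0, 1], then
# renormalize only the rows with a nonzero sum.
stage_interp = np.clip(stage_interp, 0.0, 1.0)
row_sums = stage_interp.sum(axis=1, keepdims=True)
nz = row_sums.squeeze(-1) > 0
stage_interp[nz] = stage_interp[nz] / row_sums[nz]

print(stage_interp)  # nonzero rows now sum to 1; all-zero rows stay zero
```

Guarding on `nz` avoids a divide-by-zero for frames where every stage probability was clipped to zero.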
def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
    """Generate all frame indices, ordered by offset for cache-friendly access.

    Orders frames as: [0, 30, 60...], [1, 31, 61...], ..., [29, 59, 89...]
    This groups frames that share similar temporal windows together.
    """
    num_frames = ep_end - ep_start
    indices = []
    for offset in range(frame_gap):
        for frame_rel in range(offset, num_frames, frame_gap):
            indices.append(ep_start + frame_rel)
    return indices

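The offset-major ordering is easiest to see with small numbers. This standalone sketch re-declares the function above and runs it on a 6-frame episode with a gap of 3:

```python
def generate_all_frame_indices(ep_start: int, ep_end: int, frame_gap: int = 30) -> list[int]:
    """Offset-major ordering: all frames at offset 0, then offset 1, and so on."""
    num_frames = ep_end - ep_start
    indices = []
    for offset in range(frame_gap):
        for frame_rel in range(offset, num_frames, frame_gap):
            indices.append(ep_start + frame_rel)
    return indices

# 6 frames, gap of 3: frames one gap apart (which share a temporal
# window) are emitted consecutively.
print(generate_all_frame_indices(0, 6, frame_gap=3))  # [0, 3, 1, 4, 2, 5]
```

Every frame in `[ep_start, ep_end)` appears exactly once; only the visit order changes.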
def interpolate_progress(
    computed_indices: np.ndarray,
    computed_values: np.ndarray,
    all_indices: np.ndarray,
) -> np.ndarray:
    """Linearly interpolate values to fill in gaps (robust to NaNs / edge cases)."""
    computed_indices = np.asarray(computed_indices)
    computed_values = np.asarray(computed_values)
    all_indices = np.asarray(all_indices)

    mask = np.isfinite(computed_values)
    if mask.sum() == 0:
        return np.full(all_indices.shape, np.nan, dtype=np.float32)
    if mask.sum() == 1:
        return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32)

    out = np.interp(all_indices, computed_indices[mask], computed_values[mask])
    return out.astype(np.float32)

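A quick demonstration of the interpolation helper above, re-declared here so the snippet is self-contained. Progress computed only at a few strided frames is filled in linearly for the frames in between:

```python
import numpy as np


def interpolate_progress(computed_indices, computed_values, all_indices):
    """Linear interpolation over the finite computed values (as above)."""
    computed_indices = np.asarray(computed_indices)
    computed_values = np.asarray(computed_values)
    all_indices = np.asarray(all_indices)

    mask = np.isfinite(computed_values)
    if mask.sum() == 0:
        return np.full(all_indices.shape, np.nan, dtype=np.float32)
    if mask.sum() == 1:
        return np.full(all_indices.shape, float(computed_values[mask][0]), dtype=np.float32)
    return np.interp(all_indices, computed_indices[mask], computed_values[mask]).astype(np.float32)


# Progress computed only at frames 0 and 4; frames 1-3 are filled linearly.
out = interpolate_progress([0, 4], [0.0, 1.0], np.arange(5))
print(out)  # [0.   0.25 0.5  0.75 1.  ]

# NaN entries are masked out before interpolation; a single finite value
# yields a constant array.
const = interpolate_progress([0, 4], [0.5, np.nan], np.arange(3))
```

`np.interp` also clamps to the edge values outside the computed range, which is why the last computed frame of each episode matters for good behavior near the episode end.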
def compute_sarm_progress(
    dataset_repo_id: str,
    reward_model_path: str,
    output_path: str | None = None,
    head_mode: str = "sparse",
    device: str = "cuda",
    num_visualizations: int = 5,
    output_dir: str = "./sarm_viz",
    stride: int = 1,
):
    """
    Compute SARM progress predictions for all frames in a dataset.

    Args:
        dataset_repo_id: HuggingFace dataset repo ID or local path
        reward_model_path: Path to pretrained SARM model
        output_path: Path to save results. If None, saves to dataset's cache directory
        head_mode: SARM head to use ("sparse", "dense", or "both")
        device: Device to use for inference
        num_visualizations: Number of episodes to visualize (0 to skip)
        output_dir: Directory to save visualizations
        stride: Compute progress every N frames, interpolate the rest (default: 1 = every frame)
    """
    dataset, reward_model, preprocess = load_sarm_resources(dataset_repo_id, reward_model_path, device)

    # Set preprocessor to eval mode to disable augmentations
    if hasattr(preprocess, "eval"):
        preprocess.eval()
    for step in preprocess.steps:
        if hasattr(step, "eval"):
            step.eval()

    image_key = reward_model.config.image_key
    state_key = reward_model.config.state_key
    frame_gap = reward_model.config.frame_gap
    num_episodes = dataset.num_episodes
    total_frames = dataset.num_frames
    logging.info(f"Processing {total_frames} frames across {num_episodes} episodes")

    # Determine which heads to compute
    dual_mode = reward_model.config.uses_dual_heads
    compute_sparse = head_mode in ("sparse", "both") or not dual_mode
    compute_dense = head_mode in ("dense", "both") and dual_mode

    # Storage arrays
    all_indices = []
    all_episode_indices = []
    all_frame_indices = []
    all_progress_sparse = [] if compute_sparse else None
    all_progress_dense = [] if compute_dense else None

    if stride > 1:
        logging.info(f"Using stride={stride}: computing every {stride} frames, interpolating the rest")

    # Process all episodes
    for episode_idx in tqdm(range(num_episodes), desc="Episodes"):
        ep = dataset.meta.episodes[episode_idx]
        ep_start = ep["dataset_from_index"]
        ep_end = ep["dataset_to_index"]

        # Get task description
        task = dataset[ep_start].get("task", "perform the task")

        # Generate frames to compute (with stride applied)
        all_ep_indices = generate_all_frame_indices(ep_start, ep_end, frame_gap)
        if stride > 1:
            # Only compute every stride-th frame (relative to episode start)
            compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0]
            # Always include last frame for better interpolation at episode end
            last_frame = ep_end - 1
            if last_frame not in compute_indices:
                compute_indices.append(last_frame)
            compute_indices = sorted(set(compute_indices))
        else:
            compute_indices = all_ep_indices

        center_idx = reward_model.config.n_obs_steps // 2  # Center of bidirectional window

        # Dictionary to collect results
        frame_results = {}

        for query_idx in tqdm(compute_indices, desc=f" Ep {episode_idx}", leave=False):
            try:
                sample = dataset[query_idx]

                batch = {
                    image_key: sample[image_key],
                    "task": task,
                    "index": query_idx,
                    "episode_index": episode_idx,
                }
                if state_key in sample:
                    batch[state_key] = sample[state_key]

                with torch.no_grad():
                    processed = preprocess(batch)
                    video_features = processed["video_features"].to(device)
                    text_features = processed["text_features"].to(device)
                    state_features = processed.get("state_features")
                    if state_features is not None:
                        state_features = state_features.to(device)
                    lengths = processed.get("lengths")

                    sparse_val = np.nan
                    dense_val = np.nan

                    # Compute sparse prediction for center frame
                    if compute_sparse:
                        sparse_progress = reward_model.calculate_rewards(
                            text_embeddings=text_features,
                            video_embeddings=video_features,
                            state_features=state_features,
                            lengths=lengths,
                            return_all_frames=True,
                            head_mode="sparse",
                        )
                        sparse_val = float(
                            sparse_progress[0, center_idx]
                            if sparse_progress.ndim == 2
                            else sparse_progress[center_idx]
                        )

                    # Compute dense prediction for center frame
                    if compute_dense:
                        dense_progress = reward_model.calculate_rewards(
                            text_embeddings=text_features,
                            video_embeddings=video_features,
                            state_features=state_features,
                            lengths=lengths,
                            return_all_frames=True,
                            head_mode="dense",
                        )
                        dense_val = float(
                            dense_progress[0, center_idx]
                            if dense_progress.ndim == 2
                            else dense_progress[center_idx]
                        )

                frame_results[query_idx] = (sparse_val, dense_val)

            except Exception as e:
                logging.warning(f"Failed to process frame {query_idx}: {e}")

        # Interpolate to get values for all frames
        computed_indices = np.array(sorted(frame_results.keys()))
        computed_sparse = (
            np.array([frame_results[i][0] for i in computed_indices]) if compute_sparse else None
        )
        computed_dense = np.array([frame_results[i][1] for i in computed_indices]) if compute_dense else None

        # All frame indices for this episode
        all_frame_idx_array = np.arange(ep_start, ep_end)

        if stride > 1 and len(computed_indices) > 1:
            # Interpolate progress values
            if compute_sparse:
                interp_sparse = interpolate_progress(computed_indices, computed_sparse, all_frame_idx_array)
            if compute_dense:
                interp_dense = interpolate_progress(computed_indices, computed_dense, all_frame_idx_array)
        else:
            # No interpolation needed
            interp_sparse = computed_sparse if compute_sparse else None
            interp_dense = computed_dense if compute_dense else None

        # Store results for all frames
        for i, frame_idx in enumerate(all_frame_idx_array):
            local_idx = frame_idx - ep_start
            all_indices.append(frame_idx)
            all_episode_indices.append(episode_idx)
            all_frame_indices.append(local_idx)
            if compute_sparse:
                if stride > 1 and len(computed_indices) > 1:
                    all_progress_sparse.append(float(interp_sparse[i]))
                elif frame_idx in frame_results:
                    all_progress_sparse.append(frame_results[frame_idx][0])
                else:
                    all_progress_sparse.append(np.nan)
            if compute_dense:
                if stride > 1 and len(computed_indices) > 1:
                    all_progress_dense.append(float(interp_dense[i]))
                elif frame_idx in frame_results:
                    all_progress_dense.append(frame_results[frame_idx][1])
                else:
                    all_progress_dense.append(np.nan)

    # Create output table
    table_data = {
        "index": np.array(all_indices, dtype=np.int64),
        "episode_index": np.array(all_episode_indices, dtype=np.int64),
        "frame_index": np.array(all_frame_indices, dtype=np.int64),
    }
    if compute_sparse:
        table_data["progress_sparse"] = np.array(all_progress_sparse, dtype=np.float32)
    if compute_dense:
        table_data["progress_dense"] = np.array(all_progress_dense, dtype=np.float32)

    # Sort by index
    df = pa.table(table_data).to_pandas()
    df = df.sort_values("index").reset_index(drop=True)
    final_table = pa.Table.from_pandas(df, preserve_index=False)

    # Add metadata with reward model path
    metadata = {b"reward_model_path": reward_model_path.encode()}
    final_table = final_table.replace_schema_metadata(metadata)

    # Determine output path
    output_path = Path(dataset.root) / "sarm_progress.parquet" if output_path is None else Path(output_path)

    # Save
    output_path.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(final_table, output_path)
    logging.info(f"Saved {len(final_table)} frame progress values to {output_path}")

    # Print statistics
    if "progress_sparse" in df.columns:
        valid = df["progress_sparse"].dropna()
        logging.info(
            f"Sparse progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
            f"min={valid.min():.4f}, max={valid.max():.4f}"
        )

    if "progress_dense" in df.columns:
        valid = df["progress_dense"].dropna()
        logging.info(
            f"Dense progress: mean={valid.mean():.4f}, std={valid.std():.4f}, "
            f"min={valid.min():.4f}, max={valid.max():.4f}"
        )

    # Visualize episodes after processing
    if num_visualizations > 0:
        viz_episodes = list(range(min(num_visualizations, num_episodes)))
        logging.info(f"Generating {len(viz_episodes)} visualizations...")
        visualize_sarm_predictions(
            dataset=dataset,
            reward_model=reward_model,
            preprocess=preprocess,
            episode_indices=viz_episodes,
            head_mode=head_mode,
            output_dir=Path(output_dir),
            stride=stride,
        )

    return output_path

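The stride subsampling inside `compute_sarm_progress` (compute every stride-th frame, always keep the episode's last frame so interpolation covers the tail) can be sketched standalone. For simplicity this uses a plain `range` in place of the cache-friendly ordering from `generate_all_frame_indices`:

```python
# Hypothetical episode bounds and stride, chosen only for illustration.
ep_start, ep_end, stride = 0, 10, 4
all_ep_indices = list(range(ep_start, ep_end))  # stand-in for generate_all_frame_indices

# Same selection logic as the script: every stride-th frame relative to
# episode start, plus the final frame if it was skipped.
compute_indices = [idx for idx in all_ep_indices if (idx - ep_start) % stride == 0]
last_frame = ep_end - 1
if last_frame not in compute_indices:
    compute_indices.append(last_frame)
compute_indices = sorted(set(compute_indices))

print(compute_indices)  # [0, 4, 8, 9]
```

Without the `last_frame` addition, `np.interp` would hold the value at frame 8 constant through frame 9 instead of interpolating to the true episode end.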
def main():
    parser = argparse.ArgumentParser(
        description="Compute SARM progress values for RA-BC weighting or visualize SARM predictions",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Full RA-BC computation with visualizations
  python src/lerobot/policies/sarm/compute_rabc_weights.py \\
      --dataset-repo-id lerobot/aloha_sim_insertion_human \\
      --reward-model-path pepijn223/sarm_single_uni4

  # Visualize predictions only (no RA-BC computation)
  python src/lerobot/policies/sarm/compute_rabc_weights.py \\
      --dataset-repo-id lerobot/aloha_sim_insertion_human \\
      --reward-model-path pepijn223/sarm_single_uni4 \\
      --visualize-only \\
      --num-visualizations 10
""",
    )
    parser.add_argument(
        "--dataset-repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repo ID or local path",
    )
    parser.add_argument(
        "--reward-model-path",
        type=str,
        default=None,
        help="Path to pretrained SARM model (reads from existing parquet metadata if not provided)",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        default=None,
        help="Output path for parquet. If not set, saves to dataset's cache directory",
    )
    parser.add_argument(
        "--head-mode",
        type=str,
        default="sparse",
        choices=["sparse", "dense", "both"],
        help="SARM head to use (default: sparse)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use (default: cuda)",
    )
    # Visualization options
    parser.add_argument(
        "--visualize-only",
        action="store_true",
        help="Only visualize SARM predictions (no RA-BC computation)",
    )
    parser.add_argument(
        "--num-visualizations",
        type=int,
        default=5,
        help="Number of episodes to visualize (default: 5, set to 0 to skip)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./sarm_viz",
        help="Output directory for visualizations (default: ./sarm_viz)",
    )
    # Enabled by default; pass --no-push-to-hub to skip the upload.
    # (A plain store_true flag with default=True could never be disabled.)
    parser.add_argument(
        "--push-to-hub",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Upload progress file to the dataset repo on HuggingFace Hub",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Compute progress every N frames, interpolate the rest (default: 1 = every frame)",
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    # Try to get reward_model_path from parquet metadata if not provided
    reward_model_path = args.reward_model_path
    if reward_model_path is None:
        # Load dataset to find parquet path
        temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
        parquet_path = Path(temp_dataset.root) / "sarm_progress.parquet"
        reward_model_path = get_reward_model_path_from_parquet(parquet_path)
        if reward_model_path:
            logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
        else:
            raise ValueError(
                "--reward-model-path is required (no existing parquet with model metadata found)"
            )

    # Handle visualize-only mode
    if args.visualize_only:
        dataset, reward_model, preprocess = load_sarm_resources(
            args.dataset_repo_id, reward_model_path, args.device
        )
        logging.info(f"Visualization-only mode: visualizing {args.num_visualizations} episodes")
        viz_episodes = list(range(min(args.num_visualizations, dataset.num_episodes)))
        visualize_sarm_predictions(
            dataset=dataset,
            reward_model=reward_model,
            preprocess=preprocess,
            episode_indices=viz_episodes,
            head_mode=args.head_mode,
            output_dir=Path(args.output_dir),
            stride=args.stride,
        )
        print(f"\nVisualizations saved to: {Path(args.output_dir).absolute()}")
        return

    # Full RA-BC computation (compute_sarm_progress loads model/dataset itself)
    output_path = compute_sarm_progress(
        dataset_repo_id=args.dataset_repo_id,
        reward_model_path=reward_model_path,
        output_path=args.output_path,
        head_mode=args.head_mode,
        device=args.device,
        num_visualizations=args.num_visualizations,
        output_dir=args.output_dir,
        stride=args.stride,
    )

    print(f"\nSARM progress values saved to: {output_path}")

    # Upload to Hub if requested
    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        hub_path = "sarm_progress.parquet"

        print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
        api.upload_file(
            path_or_fileobj=str(output_path),
            path_in_repo=hub_path,
            repo_id=args.dataset_repo_id,
            repo_type="dataset",
        )
        print(
            f"Successfully uploaded to: https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
        )

        print("\nTo use in training, add to your config:")
        print("  use_rabc: true")
        print(f"  rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
        print("  rabc_head_mode: sparse  # or dense")
    else:
        print("\nTo use in training, add to your config:")
        print("  use_rabc: true")
        print(f"  rabc_progress_path: {output_path}")
        print("  rabc_head_mode: sparse  # or dense")


if __name__ == "__main__":
    main()
@@ -0,0 +1,248 @@
#!/usr/bin/env python

# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.
Paper: https://arxiv.org/abs/2509.25358
"""

from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig


@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
    """Configuration class for SARM (Stage-Aware Reward Modeling).

    Supports three annotation modes:

    1. single_stage (default): No annotations needed. Uses the episode's task description
       as a single stage covering the entire episode.

    2. dense_only: Uses dense (fine-grained) annotations from a VLM, with an auto-generated
       single sparse "task" stage covering the full episode. The dense head learns detailed
       subtask progression while the sparse head provides overall task completion.

    3. dual: Full dual-head mode with both sparse (high-level) and dense (fine-grained)
       annotations from a VLM. Both heads are trained on their respective annotations.

    The annotation_mode determines how sparse_temporal_proportions and dense_temporal_proportions
    are loaded/generated during model initialization.
    """

    annotation_mode: str = "single_stage"  # "single_stage", "dense_only", or "dual"
    n_obs_steps: int = 8  # Number of observation history steps
    frame_gap: int = 30  # Frame gap between frames (at 30 fps = 1 second)
    max_rewind_steps: int = 4  # Maximum rewind steps for temporal augmentation

    # Total frames = 1 + n_obs_steps + max_rewind_steps (computed in property)
    # During training with rewind: [obs_frames] + [rewind_frames]
    # During inference: [obs_frames] only

    # Architecture params
    image_dim: int = 512
    text_dim: int = 512
    hidden_dim: int = 768
    num_heads: int = 12
    num_layers: int = 8
    max_state_dim: int = 32
    drop_n_last_frames: int = 1
    batch_size: int = 64
    clip_batch_size: int = 64
    dropout: float = 0.1
    stage_loss_weight: float = 1.0  # Weight for stage classification loss when using subtask annotations

    rewind_probability: float = 0.8
    language_perturbation_probability: float = 0.2

    # Sparse annotations (high-level stages)
    num_sparse_stages: int = 1
    sparse_subtask_names: list | None = None
    sparse_temporal_proportions: list | None = None

    # Dense annotations (fine-grained stages)
    num_dense_stages: int | None = None
    dense_subtask_names: list | None = None
    dense_temporal_proportions: list | None = None

    pretrained_model_path: str | None = None
    device: str | None = None
    image_key: str = "observation.images.top"  # Key for the image used from the dataset
    state_key: str = "observation.state"

    # Populated by the processor (video_features, state_features, text_features)
    input_features: dict = field(default_factory=lambda: {})

    # Output features (updated in __post_init__)
    output_features: dict = field(
        default_factory=lambda: {
            "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
            "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
        }
    )

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "LANGUAGE": NormalizationMode.IDENTITY,
            "REWARD": NormalizationMode.IDENTITY,
        }
    )

    def __post_init__(self):
        super().__post_init__()

        if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
            raise ValueError(
                f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
            )

        if self.annotation_mode == "single_stage":
            # Use task description as stage name, full episode as one stage
            self.num_sparse_stages = 1
            self.sparse_subtask_names = ["task"]
            self.sparse_temporal_proportions = [1.0]
            self.num_dense_stages = None
            self.dense_subtask_names = None
            self.dense_temporal_proportions = None

        elif self.annotation_mode == "dense_only":
            self.num_sparse_stages = 1
            self.sparse_subtask_names = ["task"]
            self.sparse_temporal_proportions = [1.0]

        self.input_features = {}
        self.output_features = {}

        if self.image_key:
            self.input_features[self.image_key] = PolicyFeature(shape=(480, 640, 3), type=FeatureType.VISUAL)

        self.input_features[self.state_key] = PolicyFeature(
            shape=(self.max_state_dim,),
            type=FeatureType.STATE,
        )

        # Update output features based on annotation_mode
        if self.annotation_mode in ["dense_only", "dual"]:
            self.output_features["sparse_stage"] = PolicyFeature(
                shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
            )
            self.output_features["sparse_progress"] = PolicyFeature(
                shape=(self.num_frames, 1), type=FeatureType.REWARD
            )
            dense_stages = self.num_dense_stages or self.num_sparse_stages
            self.output_features["dense_stage"] = PolicyFeature(
                shape=(self.num_frames, dense_stages), type=FeatureType.REWARD
            )
            self.output_features["dense_progress"] = PolicyFeature(
                shape=(self.num_frames, 1), type=FeatureType.REWARD
            )
        else:
            self.output_features["sparse_stage"] = PolicyFeature(
                shape=(self.num_frames, self.num_sparse_stages), type=FeatureType.REWARD
            )
            self.output_features["sparse_progress"] = PolicyFeature(
                shape=(self.num_frames, 1), type=FeatureType.REWARD
            )

        if self.max_rewind_steps >= self.n_obs_steps:
            raise ValueError(
                f"max_rewind_steps ({self.max_rewind_steps}) must be less than n_obs_steps ({self.n_obs_steps})"
            )
        if self.num_sparse_stages < 1:
            raise ValueError(f"num_sparse_stages must be at least 1, got {self.num_sparse_stages}")
        if (
            self.annotation_mode in ["dense_only", "dual"]
            and self.num_dense_stages is not None
            and self.num_dense_stages < 2
        ):
            raise ValueError(f"num_dense_stages must be at least 2, got {self.num_dense_stages}")

    def get_optimizer_preset(self) -> AdamWConfig:
        """Get default optimizer configuration for SARM training."""
        return AdamWConfig(
            lr=5e-5,
            weight_decay=1e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Get default learning rate scheduler configuration."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=5e-5,
            decay_lr=5e-6,
            num_warmup_steps=500,
            num_decay_steps=50000,
        )

    def validate_features(self) -> None:
        pass

    @property
    def uses_dual_heads(self) -> bool:
        """Whether the model uses dual heads (dense_only or dual annotation modes)."""
        return self.annotation_mode in ["dense_only", "dual"]

    @property
    def num_frames(self) -> int:
        """Total number of frames in the sequence.

        For training: 1 + n_obs_steps + max_rewind_steps
        The sequence is: [obs_frames (n_obs_steps + 1)] + [rewind_frames (max_rewind_steps)]
        """
        return 1 + self.n_obs_steps + self.max_rewind_steps

    @property
    def max_length(self) -> int:
        return self.num_frames

    @property
    def observation_delta_indices(self) -> list[int]:
        """Bidirectional frame sampling centered on the target frame.

        Example with n_obs_steps=8, gap=30:
            Before:  [-120, -90, -60, -30]  (4 frames)
            Current: [0]                    (1 frame)
            After:   [30, 60, 90, 120]      (4 frames)
            Total:   9 frames
        """
        half_steps = self.n_obs_steps // 2

        past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
        future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
        obs_deltas = past_deltas + [0] + future_deltas

        # Rewind placeholders
        rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]

        return obs_deltas + rewind_deltas

    @property
    def action_delta_indices(self) -> None:
        """SARM is a reward model, not an action policy."""
        return None

    @property
    def reward_delta_indices(self) -> None:
        return None
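The delta-index layout described in `observation_delta_indices` can be reproduced standalone. This sketch hard-codes the config defaults (n_obs_steps=8, frame_gap=30, max_rewind_steps=4) rather than importing SARMConfig:

```python
n_obs_steps, frame_gap, max_rewind_steps = 8, 30, 4
half_steps = n_obs_steps // 2

# Bidirectional window around the target frame, one gap apart.
past = [-frame_gap * i for i in range(half_steps, 0, -1)]   # [-120, -90, -60, -30]
future = [frame_gap * i for i in range(1, half_steps + 1)]  # [30, 60, 90, 120]
obs_deltas = past + [0] + future                            # 9 observation frames

# Rewind placeholders appended for temporal augmentation during training.
rewind_deltas = [-frame_gap * (i + 1) for i in range(max_rewind_steps)]

print(obs_deltas)                       # [-120, -90, -60, -30, 0, 30, 60, 90, 120]
print(len(obs_deltas + rewind_deltas))  # 13 == 1 + n_obs_steps + max_rewind_steps
```

The total length matches the `num_frames` property, which is why the two must be kept in sync.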
@@ -0,0 +1,793 @@
#!/usr/bin/env python

# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation.

Paper: https://arxiv.org/abs/2509.25358

- StageTransformer: Predicts stage classification (sparse/dense)
- SubtaskTransformer: Predicts within-stage progress (tau) conditioned on stage
"""

import json
import logging
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from torch import Tensor

from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import (
    normalize_stage_tau,
    pad_state_to_max_dim,
)


class StageTransformer(nn.Module):
    """
    Stage classification transformer for SARM.

    Predicts which stage/subtask the current frame belongs to.
    Supports both sparse (high-level) and dense (fine-grained) annotation schemes.

    Input streams: [vis_proj, lang_proj, state_proj] concatenated -> (B, N+2, T, D)
|
||||||
|
Output: stage logits (B, T, num_classes)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int = 512,
|
||||||
|
vis_emb_dim: int = 512,
|
||||||
|
text_emb_dim: int = 512,
|
||||||
|
state_dim: int = 32,
|
||||||
|
n_layers: int = 6,
|
||||||
|
n_heads: int = 8,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
num_cameras: int = 1,
|
||||||
|
num_classes_sparse: int = 4,
|
||||||
|
num_classes_dense: int = 8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.num_cameras = num_cameras
|
||||||
|
|
||||||
|
# Projections
|
||||||
|
self.lang_proj = nn.Linear(text_emb_dim, d_model)
|
||||||
|
self.visual_proj = nn.Linear(vis_emb_dim, d_model)
|
||||||
|
self.state_proj = nn.Linear(state_dim, d_model)
|
||||||
|
|
||||||
|
# Encoder
|
||||||
|
enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
|
||||||
|
self.transformer = nn.TransformerEncoder(enc_layer, n_layers)
|
||||||
|
|
||||||
|
# Positional bias on first visual frame
|
||||||
|
self.first_pos = nn.Parameter(torch.zeros(1, d_model))
|
||||||
|
|
||||||
|
# Shared fusion MLP
|
||||||
|
# Fuses (num_cameras + 2) streams: cameras + lang + state
|
||||||
|
fused_in = d_model * (num_cameras + 2)
|
||||||
|
self.fusion_backbone = nn.Sequential(
|
||||||
|
nn.LayerNorm(fused_in),
|
||||||
|
nn.Linear(fused_in, d_model),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scheme-specific heads
|
||||||
|
self.heads = nn.ModuleDict(
|
||||||
|
{
|
||||||
|
"sparse": nn.Linear(d_model, num_classes_sparse),
|
||||||
|
"dense": nn.Linear(d_model, num_classes_dense),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor: # noqa: N803
|
||||||
|
"""
|
||||||
|
Prepare language embeddings for fusion.
|
||||||
|
|
||||||
|
Accepts lang_emb of shape:
|
||||||
|
- (B, text_emb_dim) -> broadcast across time
|
||||||
|
- (B, T, text_emb_dim) -> per-timestep (dense annotation mode)
|
||||||
|
|
||||||
|
Returns: (B, 1, T, D)
|
||||||
|
"""
|
||||||
|
if lang_emb.dim() == 3:
|
||||||
|
# (B, T, E) -> (B, T, D) -> (B, 1, T, D)
|
||||||
|
lang_proj = self.lang_proj(lang_emb).unsqueeze(1)
|
||||||
|
else:
|
||||||
|
# (B, E) -> (B, 1, 1, D) -> expand to (B, 1, T, D)
|
||||||
|
lang_proj = self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)
|
||||||
|
return lang_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
img_seq: torch.Tensor, # (B, N, T, vis_emb_dim)
|
||||||
|
lang_emb: torch.Tensor, # (B, E) or (B, T, E)
|
||||||
|
state: torch.Tensor, # (B, T, state_dim)
|
||||||
|
lengths: torch.Tensor, # (B,) - valid sequence lengths
|
||||||
|
scheme: str = "sparse", # "sparse" or "dense"
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass for stage classification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_seq: Image embeddings (B, N, T, vis_emb_dim) where N=num_cameras
|
||||||
|
lang_emb: Language embeddings (B, E) or (B, T, E) for dense
|
||||||
|
state: State features (B, T, state_dim)
|
||||||
|
lengths: Valid sequence lengths (B,) for masking
|
||||||
|
scheme: "sparse" or "dense" for head selection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Stage logits (B, T, num_classes)
|
||||||
|
"""
|
||||||
|
assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."
|
||||||
|
|
||||||
|
B, N, T, _ = img_seq.shape # noqa: N806
|
||||||
|
D = self.d_model # noqa: N806
|
||||||
|
device = img_seq.device
|
||||||
|
|
||||||
|
# Project inputs
|
||||||
|
vis_proj = self.visual_proj(img_seq) # (B, N, T, D)
|
||||||
|
state_proj = self.state_proj(state).unsqueeze(1) # (B, 1, T, D)
|
||||||
|
lang_proj = self._prep_lang(lang_emb, B, T, D) # (B, 1, T, D)
|
||||||
|
|
||||||
|
# Concatenate streams
|
||||||
|
# cameras + lang + state -> (B, N+2, T, D)
|
||||||
|
x = torch.cat([vis_proj, lang_proj, state_proj], dim=1)
|
||||||
|
|
||||||
|
# Add positional bias to first visual frame
|
||||||
|
x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos
|
||||||
|
|
||||||
|
# Flatten to tokens for Transformer
|
||||||
|
x_tokens = x.view(B, (N + 2) * T, D)
|
||||||
|
L = x_tokens.size(1) # noqa: N806
|
||||||
|
|
||||||
|
# Create padding mask
|
||||||
|
base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1) # (B, T)
|
||||||
|
mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T)
|
||||||
|
|
||||||
|
# Create causal mask
|
||||||
|
causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)
|
||||||
|
|
||||||
|
# Encode
|
||||||
|
h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)
|
||||||
|
|
||||||
|
# Reshape and fuse
|
||||||
|
h = h.view(B, N + 2, T, D).permute(0, 2, 1, 3).reshape(B, T, (N + 2) * D)
|
||||||
|
fused = self.fusion_backbone(h) # (B, T, D)
|
||||||
|
|
||||||
|
# Scheme-specific logits
|
||||||
|
logits = self.heads[scheme](fused) # (B, T, num_classes)
|
||||||
|
return logits
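
The padding and causal masks built in `forward` can be illustrated standalone with toy shapes (the values below are illustrative, not the model defaults). Note that tokens are flattened stream-major, so the triangular mask operates over the flattened `(N+2)*T` token order.

```python
import torch

# Toy shapes: B=2 sequences with valid lengths 3 and 2 out of T=4, N=1 camera.
B, N, T = 2, 1, 4
lengths = torch.tensor([3, 2])

# Padding mask: True marks padded timesteps, replicated across the N+2 streams.
base_mask = torch.arange(T).expand(B, T) >= lengths.unsqueeze(1)           # (B, T)
mask = base_mask.unsqueeze(1).expand(B, N + 2, T).reshape(B, (N + 2) * T)  # (B, (N+2)*T)

# Causal mask: True above the diagonal blocks attention to later tokens.
L = (N + 2) * T
causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

print(base_mask[0].tolist())  # [False, False, False, True]
print(mask.shape, causal.shape)
```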


class SubtaskTransformer(nn.Module):
    """
    Subtask progress regression transformer for SARM.

    Predicts within-stage normalized progress (tau) conditioned on a stage prior.
    The stage prior is a one-hot encoding passed from StageTransformer predictions.

    Input streams: [vis_proj, lang_proj, state_proj, stage_emb] -> (B, N+3, T, D)
    Output: tau predictions (B, T) in [0, 1]
    """

    def __init__(
        self,
        d_model: int = 512,
        vis_emb_dim: int = 512,
        text_emb_dim: int = 512,
        state_dim: int = 32,
        n_layers: int = 6,
        n_heads: int = 8,
        dropout: float = 0.1,
        num_cameras: int = 1,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_cameras = num_cameras

        # Projections
        self.lang_proj = nn.Linear(text_emb_dim, d_model)
        self.visual_proj = nn.Linear(vis_emb_dim, d_model)
        self.state_proj = nn.Linear(state_dim, d_model)

        # Encoder
        enc = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, n_layers)

        # Learned bias on the first visual frame
        self.first_pos = nn.Parameter(torch.zeros(1, d_model))

        # Shared fusion backbone
        # Fuses (num_cameras + 3) streams: cameras + lang + state + stage_emb
        fused_in = d_model * (num_cameras + 3)
        self.fusion_backbone = nn.Sequential(
            nn.LayerNorm(fused_in),
            nn.Linear(fused_in, d_model),
            nn.ReLU(),
        )

        # Scheme-specific regression heads
        self.heads = nn.ModuleDict(
            {
                "sparse": nn.Linear(d_model, 1),
                "dense": nn.Linear(d_model, 1),
            }
        )

    def _prep_lang(self, lang_emb: torch.Tensor, B: int, T: int, D: int) -> torch.Tensor:  # noqa: N803
        """Prepare language embeddings for fusion. Returns (B, 1, T, D)."""
        if lang_emb.dim() == 3:
            # (B, T, E) -> (B, T, D) -> (B, 1, T, D)
            return self.lang_proj(lang_emb).unsqueeze(1)
        else:
            # (B, E) -> (B, 1, 1, D) -> (B, 1, T, D)
            return self.lang_proj(lang_emb).unsqueeze(1).unsqueeze(2).expand(B, 1, T, D)

    def _stage_to_dmodel(self, stage_prior: torch.Tensor) -> torch.Tensor:
        """
        Deterministic projection of the one-hot stage to d_model by pad/truncate.

        Args:
            stage_prior: One-hot stage embedding (B, 1, T, C)

        Returns:
            Projected stage embedding (B, 1, T, d_model)
        """
        B, one, T, C = stage_prior.shape  # noqa: N806
        D = self.d_model  # noqa: N806
        if D == C:
            return stage_prior
        elif D > C:
            pad = torch.zeros(B, one, T, D - C, device=stage_prior.device, dtype=stage_prior.dtype)
            return torch.cat([stage_prior, pad], dim=-1)
        else:
            return stage_prior[..., :D]

    def forward(
        self,
        img_seq: torch.Tensor,  # (B, N, T, vis_emb_dim)
        lang_emb: torch.Tensor,  # (B, E) or (B, T, E)
        state: torch.Tensor,  # (B, T, state_dim)
        lengths: torch.Tensor,  # (B,) - valid sequence lengths
        stage_prior: torch.Tensor,  # (B, 1, T, C) one-hot from gen_stage_emb
        scheme: str = "sparse",  # "sparse" or "dense"
    ) -> torch.Tensor:
        """
        Forward pass for subtask progress regression.

        Args:
            img_seq: Image embeddings (B, N, T, vis_emb_dim)
            lang_emb: Language embeddings (B, E) or (B, T, E)
            state: State features (B, T, state_dim)
            lengths: Valid sequence lengths (B,) for masking
            stage_prior: One-hot stage prior (B, 1, T, num_classes)
            scheme: "sparse" or "dense" for head selection

        Returns:
            Tau predictions (B, T) in [0, 1] via sigmoid
        """
        assert scheme in self.heads, f"Unknown scheme '{scheme}'. Use one of {list(self.heads.keys())}."

        B, N, T, _ = img_seq.shape  # noqa: N806
        D = self.d_model  # noqa: N806
        device = img_seq.device

        # Project inputs
        vis_proj = self.visual_proj(img_seq)  # (B, N, T, D)
        state_proj = self.state_proj(state).unsqueeze(1)  # (B, 1, T, D)
        lang_proj = self._prep_lang(lang_emb, B, T, D)  # (B, 1, T, D)
        stage_emb = self._stage_to_dmodel(stage_prior)  # (B, 1, T, D)

        # Concatenate all streams: cameras + lang + state + stage_emb -> (B, N+3, T, D)
        x = torch.cat([vis_proj, lang_proj, state_proj, stage_emb], dim=1)

        # Add positional bias to the first visual frame
        x[:, :N, 0, :] = x[:, :N, 0, :] + self.first_pos

        # Flatten to tokens
        x_tokens = x.view(B, (N + 3) * T, D)
        L = x_tokens.size(1)  # noqa: N806

        # Create padding mask (True marks padded timesteps)
        base_mask = torch.arange(T, device=device).expand(B, T) >= lengths.unsqueeze(1)
        mask = base_mask.unsqueeze(1).expand(B, N + 3, T).reshape(B, (N + 3) * T)

        # Create causal mask
        causal_mask = torch.triu(torch.ones(L, L, device=device, dtype=torch.bool), diagonal=1)

        # Encode
        h = self.transformer(x_tokens, mask=causal_mask, src_key_padding_mask=mask, is_causal=True)

        # Reshape and fuse
        h = h.view(B, N + 3, T, D)
        h_flat = h.permute(0, 2, 1, 3).reshape(B, T, (N + 3) * D)
        fused = self.fusion_backbone(h_flat)  # (B, T, D)

        # Scheme-specific regression head -> sigmoid
        r = torch.sigmoid(self.heads[scheme](fused)).squeeze(-1)  # (B, T)
        return r
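
The pad/truncate projection in `_stage_to_dmodel` can be sketched as a free function (the name `stage_to_dmodel` and the toy shapes below are illustrative only):

```python
import torch


def stage_to_dmodel(stage_prior: torch.Tensor, d_model: int) -> torch.Tensor:
    # stage_prior: (B, 1, T, C) one-hot; returns (B, 1, T, d_model)
    B, one, T, C = stage_prior.shape
    if d_model == C:
        return stage_prior
    if d_model > C:
        # Zero-pad the class dimension up to d_model.
        pad = torch.zeros(B, one, T, d_model - C, dtype=stage_prior.dtype)
        return torch.cat([stage_prior, pad], dim=-1)
    # Otherwise truncate the class dimension.
    return stage_prior[..., :d_model]


onehot = torch.eye(4)[torch.tensor([[0, 2]])].unsqueeze(1)  # (1, 1, 2, 4)
print(stage_to_dmodel(onehot, 6).shape)  # torch.Size([1, 1, 2, 6])
```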


def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
    """
    Generate one-hot stage embeddings from targets.

    Args:
        num_classes: Number of stage classes
        targets: Target values (B, T) where the integer part is the stage index

    Returns:
        One-hot stage embedding (B, 1, T, num_classes)
    """
    # Integer part of float targets -> [0, C-1]
    idx = targets.long().clamp(min=0, max=num_classes - 1)  # (B, T)
    C = num_classes  # noqa: N806
    # Identity-lookup one-hot
    stage_onehot = torch.eye(C, device=targets.device)[idx]  # (B, T, C)
    stage_onehot = stage_onehot.unsqueeze(1)  # (B, 1, T, C)
    return stage_onehot
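
The identity-lookup one-hot trick used above can be restated standalone (this is a self-contained re-statement, not an import of this module; the example targets are arbitrary):

```python
import torch

targets = torch.tensor([[0.25, 1.9, 3.0]])  # (B=1, T=3); integer part = stage index
idx = targets.long().clamp(min=0, max=3)    # truncates toward zero -> stages [0, 1, 3]
onehot = torch.eye(4)[idx].unsqueeze(1)     # row lookup in the identity -> (1, 1, 3, 4)
print(idx.tolist())  # [[0, 1, 3]]
print(onehot.shape)
```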


class SARMRewardModel(PreTrainedPolicy):
    """
    SARM reward model for stage-aware task completion rewards.

    Uses two separate transformer models:
    - StageTransformer: Classifies the current stage/subtask
    - SubtaskTransformer: Predicts within-stage progress (tau)

    Training uses 75%/25% GT/predicted stage conditioning (teacher forcing).
    """

    name = "sarm"
    config_class = SARMConfig

    def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
        super().__init__(config, dataset_stats)
        config.validate_features()
        self.config = config
        self.dataset_stats = dataset_stats
        self.device = torch.device(
            config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Load temporal proportions based on annotation_mode
        if config.annotation_mode == "single_stage":
            logging.info(f"Using single_stage mode: sparse_subtask_names={config.sparse_subtask_names}")
        elif dataset_meta is not None:
            self._load_temporal_proportions(dataset_meta)

        # Create two separate models
        self.stage_model = StageTransformer(
            d_model=config.hidden_dim,
            vis_emb_dim=config.image_dim,
            text_emb_dim=config.text_dim,
            state_dim=config.max_state_dim,
            n_layers=config.num_layers,
            n_heads=config.num_heads,
            dropout=config.dropout,
            num_cameras=1,  # Single camera for now
            num_classes_sparse=config.num_sparse_stages,
            num_classes_dense=config.num_dense_stages or config.num_sparse_stages,
        )

        self.subtask_model = SubtaskTransformer(
            d_model=config.hidden_dim,
            vis_emb_dim=config.image_dim,
            text_emb_dim=config.text_dim,
            state_dim=config.max_state_dim,
            n_layers=config.num_layers,
            n_heads=config.num_heads,
            dropout=config.dropout,
            num_cameras=1,
        )

        self.stage_model.to(self.device)
        self.subtask_model.to(self.device)

        # GT/predicted stage ratio for teacher forcing
        self.gt_stage_ratio = 0.75

        if config.uses_dual_heads:
            logging.info(
                f"SARM initialized with dual heads: {config.num_sparse_stages} sparse stages, "
                f"{config.num_dense_stages} dense stages"
            )
        else:
            logging.info(f"SARM initialized with sparse head only: {config.num_sparse_stages} stages")

        logging.info(f"SARM initialized on {self.device}")

    def _load_proportions_from_json(self, path, annotation_type: str) -> tuple[list[str], list[float]]:
        """Load temporal proportions from a JSON file (preserving order)."""
        if not path.exists():
            raise ValueError(
                f"{annotation_type.capitalize()} temporal proportions not found at {path}. "
                f"Run the subtask annotation tool with --{annotation_type}-subtasks to generate annotations."
            )
        with open(path) as f:
            proportions_dict = json.load(f)
        names = list(proportions_dict.keys())
        logging.info(f"Loaded {len(names)} {annotation_type} subtasks: {names}")
        logging.info(f"{annotation_type.capitalize()} temporal proportions: {proportions_dict}")
        return names, [proportions_dict[name] for name in names]

    def _load_temporal_proportions(self, dataset_meta) -> None:
        """Load temporal proportions based on annotation_mode."""
        meta_path = dataset_meta.root / "meta"

        if self.config.annotation_mode == "dual":
            names, props = self._load_proportions_from_json(
                meta_path / "temporal_proportions_sparse.json", "sparse"
            )
            (
                self.config.num_sparse_stages,
                self.config.sparse_subtask_names,
                self.config.sparse_temporal_proportions,
            ) = len(names), names, props

        if self.config.annotation_mode in ["dense_only", "dual"]:
            names, props = self._load_proportions_from_json(
                meta_path / "temporal_proportions_dense.json", "dense"
            )
            (
                self.config.num_dense_stages,
                self.config.dense_subtask_names,
                self.config.dense_temporal_proportions,
            ) = len(names), names, props
            if self.config.annotation_mode == "dense_only":
                logging.info(f"Using auto-generated sparse 'task' stage: {self.config.sparse_subtask_names}")

    def to(self, device):
        """Override `to` so that all components move together."""
        super().to(device)
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.stage_model.to(device)
        self.subtask_model.to(device)
        return self

    @torch.no_grad()
    def calculate_rewards(
        self,
        text_embeddings: np.ndarray | torch.Tensor,
        video_embeddings: np.ndarray | torch.Tensor,
        state_features: np.ndarray | torch.Tensor | None = None,
        lengths: np.ndarray | torch.Tensor | None = None,
        return_all_frames: bool = False,
        return_stages: bool = False,
        return_confidence: bool = False,
        head_mode: str | None = "sparse",
        frame_index: int | None = None,
    ) -> np.ndarray | tuple:
        """
        Calculate rewards for given text, video, and state representations.

        This is the canonical method for SARM reward computation, used for:
        - Inference/visualization
        - RA-BC weight computation

        Args:
            text_embeddings: Encoded text representations (batch_size, 512)
            video_embeddings: Encoded video representations (batch_size, num_frames, 512)
            state_features: Joint state features (batch_size, num_frames, state_dim)
            lengths: Valid sequence lengths (batch_size,)
            return_all_frames: If True, return rewards for all frames
            return_stages: If True, also return stage predictions
            return_confidence: If True, also return stage confidence
            head_mode: Which head to use ("sparse" or "dense")
            frame_index: Index of the target frame to extract (default: n_obs_steps).

        Returns:
            Rewards and optionally stage probs/confidence.
        """
        if isinstance(text_embeddings, np.ndarray):
            text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
        if isinstance(video_embeddings, np.ndarray):
            video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
        if state_features is not None and isinstance(state_features, np.ndarray):
            state_features = torch.tensor(state_features, dtype=torch.float32)

        # Handle single-sample case
        if text_embeddings.dim() == 1:
            text_embeddings = text_embeddings.unsqueeze(0)
            video_embeddings = video_embeddings.unsqueeze(0)
            if state_features is not None:
                state_features = state_features.unsqueeze(0)
            single_sample = True
        else:
            single_sample = False

        batch_size = video_embeddings.shape[0]
        seq_len = video_embeddings.shape[1]

        scheme = head_mode

        # Default lengths if not provided
        if lengths is None:
            lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)
        elif isinstance(lengths, np.ndarray):
            lengths = torch.tensor(lengths, dtype=torch.int32)

        # Reshape video to (B, N, T, D) for the multi-camera format.
        # Currently single camera: (B, T, D) -> (B, 1, T, D)
        img_seq = video_embeddings.unsqueeze(1).to(self.device)
        lang_emb = text_embeddings.to(self.device)
        state = (
            state_features.to(self.device)
            if state_features is not None
            else torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
        )
        lens = lengths.to(self.device)

        # Pad state to max_state_dim
        state = pad_state_to_max_dim(state, self.config.max_state_dim)

        # Get num_classes for this scheme
        num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages

        # Run stage model
        stage_logits = self.stage_model(img_seq, lang_emb, state, lens, scheme=scheme)
        stage_probs = F.softmax(stage_logits, dim=-1)  # (B, T, num_classes)
        stage_idx = stage_probs.argmax(dim=-1)  # (B, T)
        stage_conf = stage_probs.gather(-1, stage_idx.unsqueeze(-1)).squeeze(-1)  # (B, T)

        # Create one-hot stage prior
        stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float()  # (B, T, C)
        stage_emb = stage_onehot.unsqueeze(1)  # (B, 1, T, C)

        # Run subtask model
        tau_pred = self.subtask_model(img_seq, lang_emb, state, lens, stage_emb, scheme=scheme)

        # Compute final reward: stage + tau
        raw_reward = stage_idx.float() + tau_pred  # (B, T)

        # Normalize to [0, 1] using temporal proportions for proper weighting
        if scheme == "sparse":
            normalized_reward = normalize_stage_tau(
                raw_reward,
                num_stages=num_classes,
                temporal_proportions=self.config.sparse_temporal_proportions,
                subtask_names=self.config.sparse_subtask_names,
            )
        else:
            normalized_reward = normalize_stage_tau(
                raw_reward,
                num_stages=num_classes,
                temporal_proportions=self.config.dense_temporal_proportions,
                subtask_names=self.config.dense_subtask_names,
            )

        # Default frame index is n_obs_steps (the last observation frame)
        if frame_index is None:
            frame_index = self.config.n_obs_steps

        # Prepare outputs
        if return_all_frames:
            rewards = normalized_reward.cpu().numpy()
        else:
            rewards = normalized_reward[:, frame_index].cpu().numpy()

        if single_sample:
            rewards = rewards[0]

        outputs = [rewards]
        if return_stages:
            probs = stage_probs.cpu().numpy()
            if single_sample:
                probs = probs[0]
            outputs.append(probs)
        if return_confidence:
            conf = stage_conf.cpu().numpy()
            if single_sample:
                conf = conf[0]
            outputs.append(conf)

        return outputs[0] if len(outputs) == 1 else tuple(outputs)

    def train(self, mode: bool = True):
        """Set training mode for both models."""
        super().train(mode)
        self.stage_model.train(mode)
        self.subtask_model.train(mode)
        return self

    def eval(self):
        """Set evaluation mode for both models."""
        return self.train(False)

    def parameters(self):
        """Override to return trainable parameters from both models."""
        from itertools import chain

        return chain(self.stage_model.parameters(), self.subtask_model.parameters())

    def get_optim_params(self):
        """Override to return optimizer parameters from both models."""
        return self.parameters()

    def reset(self):
        """Required by PreTrainedPolicy but not used for reward models."""
        pass

    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Required by PreTrainedPolicy but not used for reward models."""
        raise NotImplementedError("SARM model does not predict action chunks")

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Required by PreTrainedPolicy but not used for SARM."""
        raise NotImplementedError("SARM model does not select actions")

    def _train_step(
        self,
        img_emb: torch.Tensor,  # (B, N, T, D)
        lang_emb: torch.Tensor,  # (B, E) or (B, T, E)
        state: torch.Tensor,  # (B, T, state_dim)
        lengths: torch.Tensor,  # (B,)
        targets: torch.Tensor,  # (B, T) - format: stage.tau
        scheme: str,
    ) -> dict[str, torch.Tensor]:
        """
        Single training step for one annotation scheme.

        Implements 75%/25% GT/predicted stage conditioning.

        Args:
            img_emb: Image embeddings (B, N, T, D)
            lang_emb: Language embeddings
            state: State features
            lengths: Valid sequence lengths
            targets: Target values where floor=stage, remainder=tau
            scheme: "sparse" or "dense"

        Returns:
            Dict with stage_loss, subtask_loss, total_loss
        """
        num_classes = self.config.num_sparse_stages if scheme == "sparse" else self.config.num_dense_stages

        # Ground truth: stage (integer part) and tau (fractional part).
        # Clamp stage indices to the valid range [0, num_classes-1] to handle edge cases
        # where targets may exceed the expected range (e.g., frames between subtasks).
        gt_stage = torch.floor(targets).long().clamp(0, num_classes - 1)  # (B, T)
        gt_tau = torch.remainder(targets, 1.0)  # (B, T)

        # Run stage model
        stage_pred = self.stage_model(img_emb, lang_emb, state, lengths, scheme=scheme)

        # 75%/25% GT/predicted stage conditioning
        if random.random() < self.gt_stage_ratio:
            # Mode 1: Use ground-truth stage -> one-hot
            stage_emb = gen_stage_emb(num_classes, targets)  # (B, 1, T, C)
        else:
            # Mode 2: Use predicted stage argmax -> one-hot
            stage_idx = stage_pred.argmax(dim=-1)  # (B, T)
            stage_onehot = F.one_hot(stage_idx, num_classes=num_classes).float()  # (B, T, C)
            stage_emb = stage_onehot.unsqueeze(1)  # (B, 1, T, C)

        # Run subtask model with stage prior
        tau_pred = self.subtask_model(img_emb, lang_emb, state, lengths, stage_emb, scheme=scheme)

        # Compute losses
        stage_loss = F.cross_entropy(stage_pred.view(-1, num_classes), gt_stage.view(-1), reduction="mean")
        subtask_loss = F.mse_loss(tau_pred, gt_tau, reduction="mean")

        return {
            "stage_loss": stage_loss,
            "subtask_loss": subtask_loss,
            "total_loss": stage_loss + subtask_loss,
        }

    def forward(self, batch):
        """
        Forward pass for SARM reward model training.

        Uses the stage.tau target format where:
        - Integer part = stage index
        - Fractional part = within-stage progress (tau)

        Training uses 75%/25% GT/predicted stage conditioning.

        Args:
            batch: Dictionary with 'observation' containing:
                - 'video_features': (B, T, 512) pre-encoded video features
                - 'text_features': (B, 512) or (B, T, 512) text features
                - 'state_features': (B, T, state_dim) joint state features
                - 'lengths': (B,) valid sequence lengths
                - 'sparse_targets': (B, T) sparse targets (stage.tau format)
                - 'dense_targets': (B, T) dense targets (optional, for dual mode)

        Returns:
            Tuple of (total_loss, output_dict with loss components)
        """
        observation = batch.get("observation", batch)

        # Extract features
        video_features = observation["video_features"].to(self.device)
        text_features = observation["text_features"].to(self.device)
        state_features = observation.get("state_features")
        if state_features is not None:
            state_features = state_features.to(self.device)

        batch_size = video_features.shape[0]
        seq_len = video_features.shape[1]

        # Get lengths (default to full sequence)
        lengths = observation.get("lengths")
        if lengths is None:
            lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device=self.device)
        else:
            lengths = lengths.to(self.device)

        # Reshape video to (B, N, T, D) - single camera
        img_emb = video_features.unsqueeze(1)

        # Pad state to max_state_dim
        if state_features is None:
            state_features = torch.zeros(batch_size, seq_len, self.config.max_state_dim, device=self.device)
        else:
            state_features = pad_state_to_max_dim(state_features, self.config.max_state_dim)

        output_dict = {}
        total_loss = torch.tensor(0.0, device=self.device)

        # Sparse training (always)
        sparse_targets = observation.get("sparse_targets")
        if sparse_targets is None:
            # Try legacy format
            sparse_targets = observation.get("targets")
            if sparse_targets is None:
                raise ValueError("sparse_targets (or targets) is required for SARM training")
        sparse_targets = sparse_targets.to(self.device)

        sparse_result = self._train_step(
            img_emb, text_features, state_features, lengths, sparse_targets, scheme="sparse"
        )
        output_dict["sparse_stage_loss"] = sparse_result["stage_loss"].item()
        output_dict["sparse_subtask_loss"] = sparse_result["subtask_loss"].item()
        total_loss = total_loss + sparse_result["total_loss"]

        # Dense training (if dual mode)
        if self.config.uses_dual_heads:
            dense_targets = observation.get("dense_targets")
            if dense_targets is not None:
                dense_targets = dense_targets.to(self.device)
                dense_result = self._train_step(
                    img_emb, text_features, state_features, lengths, dense_targets, scheme="dense"
                )
                output_dict["dense_stage_loss"] = dense_result["stage_loss"].item()
                output_dict["dense_subtask_loss"] = dense_result["subtask_loss"].item()
                total_loss = total_loss + dense_result["total_loss"]

        output_dict["total_loss"] = total_loss.item()
        return total_loss, output_dict
|
||||||
|
|
||||||
|
|
||||||
|
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Compute cross-entropy loss for stage classification."""
|
||||||
|
_, _, num_stages = stage_logits.shape
|
||||||
|
stage_logits_flat = stage_logits.reshape(-1, num_stages)
|
||||||
|
# Clamp target stage indices to valid range [0, num_stages-1]
|
||||||
|
target_stages_flat = target_stages.reshape(-1).clamp(0, num_stages - 1)
|
||||||
|
return F.cross_entropy(stage_logits_flat, target_stages_flat)
|
||||||
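The `stage.tau` target format used above packs the stage index into the integer part and the within-stage progress τ into the fractional part of a single float. A minimal pure-Python sketch of decoding it (the helper name `decode_stage_tau` is illustrative, not part of the codebase):

```python
def decode_stage_tau(target: float) -> tuple[int, float]:
    """Split a combined stage.tau target into (stage_idx, tau)."""
    stage_idx = int(target)
    tau = target - stage_idx
    return stage_idx, tau

# A target of 2.35 means: stage 2, 35% of the way through that subtask.
stage, tau = decode_stage_tau(2.35)
```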
@@ -0,0 +1,518 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SARM Processor for encoding images/text and generating stage+tau targets."""

import random
from typing import Any

import numpy as np
import pandas as pd
import torch
from faker import Faker
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import (
    apply_rewind_augmentation,
    compute_absolute_indices,
    find_stage_and_tau,
    pad_state_to_max_dim,
)
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    ProcessorStep,
    RenameObservationsProcessorStep,
)
from lerobot.processor.converters import (
    from_tensor_to_numpy,
    policy_action_to_transition,
    transition_to_policy_action,
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


class SARMEncodingProcessorStep(ProcessorStep):
    """ProcessorStep that encodes images and text with CLIP and generates stage and progress labels for SARM."""

    def __init__(
        self,
        config: SARMConfig,
        image_key: str | None = None,
        dataset_meta=None,
        dataset_stats: dict | None = None,
    ):
        super().__init__()
        self.config = config
        self.image_key = image_key or config.image_key
        self.dataset_meta = dataset_meta
        self.dataset_stats = dataset_stats
        self.annotation_mode = config.annotation_mode

        # Helper to create a temporal proportions dict
        def make_props_dict(names, props):
            return dict(zip(names, props, strict=True)) if names and props else None

        # Sparse annotations (always needed)
        self.sparse_temporal_proportions = make_props_dict(
            config.sparse_subtask_names, config.sparse_temporal_proportions
        )
        self.sparse_subtask_names = config.sparse_subtask_names

        # Dense annotations (only for dual mode)
        self.dense_subtask_names = config.dense_subtask_names if config.uses_dual_heads else None
        self.dense_temporal_proportions = (
            make_props_dict(config.dense_subtask_names, config.dense_temporal_proportions)
            if config.uses_dual_heads
            else None
        )

        self.device = torch.device(
            self.config.device if self.config.device else "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
        self.clip_model.to(self.device)
        self.clip_model.eval()

        self.verbs = ["move", "grasp", "rotate", "push", "pull", "slide", "lift", "place"]
        self.fake = Faker()

    def _find_episode_for_frame(self, frame_idx: int) -> int:
        """Find the episode index for a given frame index."""
        for ep_idx in range(len(self.dataset_meta.episodes)):
            ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
            ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
            if ep_start <= frame_idx < ep_end:
                return ep_idx
        return 0

    def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
        """Get episode indices for each frame index."""
        if episode_index is None:
            return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])

        episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))

        # If a single episode index is given for multiple frames, compute the episode for each frame
        if len(episode_indices) == 1 and len(frame_indices) > 1:
            return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])

        return episode_indices
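The episode lookup above is a linear scan over `[dataset_from_index, dataset_to_index)` ranges. A standalone sketch of the same lookup over a plain list of dicts (hypothetical data, mirroring the episode metadata fields):

```python
def find_episode_for_frame(episodes: list[dict], frame_idx: int) -> int:
    """Return the index of the episode whose [from, to) range contains frame_idx."""
    for ep_idx, ep in enumerate(episodes):
        if ep["dataset_from_index"] <= frame_idx < ep["dataset_to_index"]:
            return ep_idx
    return 0  # Fallback to episode 0, as in the processor

episodes = [
    {"dataset_from_index": 0, "dataset_to_index": 100},
    {"dataset_from_index": 100, "dataset_to_index": 250},
]
result = find_episode_for_frame(episodes, 150)  # frame 150 falls in the second episode
```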
    def _generate_perturbed_task(self) -> str:
        """Generate a random perturbed task string for language perturbation."""
        num_words = random.randint(1, 5)
        verb = random.choice(self.verbs)
        phrase = " ".join([verb] + self.fake.words(nb=num_words))
        return phrase

    def _get_annotation_config(self, annotation_type: str) -> tuple[list[str], dict[str, float] | None]:
        """Get global subtask names and temporal proportions for an annotation type."""
        if annotation_type == "dense":
            return self.dense_subtask_names, self.dense_temporal_proportions
        return self.sparse_subtask_names, self.sparse_temporal_proportions

    def _load_episode_annotations(
        self,
        ep_idx: int,
        episodes_df: pd.DataFrame | None,
        annotation_type: str,
        global_names: list[str],
    ) -> tuple[list | None, list | None, list | None]:
        """Load subtask annotations for an episode from a DataFrame."""
        # Single-stage mode (linear progress 0 -> 1)
        if episodes_df is None or len(global_names) == 1:
            return None, None, None

        # Resolve the column name with a fallback to the un-prefixed form
        def col(suffix):
            prefixed = f"{annotation_type}_{suffix}"
            return prefixed if prefixed in episodes_df.columns else suffix

        col_names = col("subtask_names")
        if col_names not in episodes_df.columns or ep_idx >= len(episodes_df):
            return None, None, None

        subtask_names = episodes_df.loc[ep_idx, col_names]
        if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
            return None, None, None

        return (
            subtask_names,
            episodes_df.loc[ep_idx, col("subtask_start_frames")],
            episodes_df.loc[ep_idx, col("subtask_end_frames")],
        )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """
        Encode images, text, and normalize states in the transition.

        Implements SARM training data preparation:
        - Applies language perturbation (20% probability)
        - Applies rewind augmentation (80% probability)
        - Generates stage+tau targets for all frames
        - Outputs a lengths tensor for valid sequence masking
        """
        new_transition = transition.copy() if hasattr(transition, "copy") else dict(transition)
        observation = new_transition.get(TransitionKey.OBSERVATION)
        comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})

        frame_index = comp_data.get("index")
        episode_index = comp_data.get("episode_index")

        if frame_index is None:
            raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
        if episode_index is None:
            raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")

        frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
        episode_indices = self._get_episode_indices(frame_indices, episode_index)

        image = observation.get(self.image_key)
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()

        # If 4D (T, C, H, W) from delta_timestamps, add a batch dim
        # If 3D (C, H, W) single frame, add batch and time dims
        if image.ndim == 4:
            image = image[np.newaxis, ...]  # (T, C, H, W) -> (1, T, C, H, W)
        elif image.ndim == 3:
            image = image[np.newaxis, np.newaxis, ...]  # (C, H, W) -> (1, 1, C, H, W)

        batch_size = image.shape[0]
        total_frames = image.shape[1]  # Should be 13: 9 obs + 4 rewind placeholders
        n_obs_steps = self.config.n_obs_steps
        max_rewind_steps = self.config.max_rewind_steps
        n_obs_frames = 1 + n_obs_steps  # 9 observation frames (including current)

        # Rewind augmentation
        rewind_steps = torch.zeros(batch_size, dtype=torch.int32)
        apply_rewind = self.training and random.random() < self.config.rewind_probability

        if apply_rewind and self.dataset_meta is not None:
            for b_idx, (ep_idx, frame_idx) in enumerate(
                zip(episode_indices.tolist(), frame_indices.tolist(), strict=True)
            ):
                ep_idx, frame_idx = int(ep_idx), int(frame_idx)
                ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]

                rewind_step, _ = apply_rewind_augmentation(
                    frame_idx, ep_start, n_obs_steps, max_rewind_steps, frame_gap=self.config.frame_gap
                )
                rewind_steps[b_idx] = rewind_step

        # Compute valid lengths: n_obs_frames + rewind_steps
        lengths = n_obs_frames + rewind_steps  # (B,)

        # Apply rewind masking to images.
        # Frames beyond the valid length are masked with zeros (or could copy the last valid frame).
        for b_idx in range(batch_size):
            valid_len = lengths[b_idx].item()
            if valid_len < total_frames:
                image[b_idx, valid_len:] = 0  # Zero out frames beyond the valid length

        # Encode images with CLIP
        video_features = self._encode_images_batch(image)
        observation["video_features"] = video_features

        state_key = self.config.state_key
        state_data = observation.get(state_key)

        if isinstance(state_data, torch.Tensor):
            state_tensor = state_data.float()
        else:
            state_tensor = torch.tensor(state_data, dtype=torch.float32)

        if state_tensor.ndim == 2:
            state_tensor = state_tensor.unsqueeze(0)  # (T, D) -> (1, T, D)
        elif state_tensor.ndim == 1:
            state_tensor = state_tensor.unsqueeze(0).unsqueeze(0)  # (D,) -> (1, 1, D)

        # Apply the same rewind masking to the state
        for b_idx in range(batch_size):
            valid_len = lengths[b_idx].item()
            if valid_len < state_tensor.shape[1]:
                state_tensor[b_idx, valid_len:] = 0  # Zero out frames beyond the valid length

        observation["state_features"] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)

        task = comp_data.get("task")
        if isinstance(task, list):
            task = task[0] if task else ""

        # Apply language perturbation during training (20% probability).
        # When perturbed, targets are zeroed to train the model to output low values for irrelevant text.
        apply_perturbation = self.training and random.random() < self.config.language_perturbation_probability
        if apply_perturbation:
            task = self._generate_perturbed_task()

        # Encode text with CLIP
        observation["text_features"] = self._encode_text_clip(task, batch_size)

        # Store lengths for the model
        observation["lengths"] = lengths

        # When language is perturbed, targets are zero so perturbed samples don't contribute to the progress loss
        if self.dataset_meta is not None:
            episodes_df = None
            if self.sparse_subtask_names != ["task"]:
                episodes_df = self.dataset_meta.episodes.to_pandas()

            # Generate sparse targets
            if self.sparse_temporal_proportions is not None:
                if apply_perturbation:
                    # Zero targets when language is perturbed
                    sparse_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
                else:
                    sparse_targets = self._compute_batch_targets(
                        frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "sparse"
                    )
                observation["sparse_targets"] = sparse_targets

            # Generate dense targets (for dual mode)
            if self.config.uses_dual_heads and self.dense_temporal_proportions is not None:
                if apply_perturbation:
                    # Zero targets when language is perturbed
                    dense_targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)
                else:
                    dense_targets = self._compute_batch_targets(
                        frame_indices, episode_indices, lengths, rewind_steps, episodes_df, "dense"
                    )
                observation["dense_targets"] = dense_targets

        new_transition[TransitionKey.OBSERVATION] = observation
        return new_transition
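The zero-masking of frames beyond each sample's valid length can be illustrated with NumPy alone (shapes assumed from the comments above: 9 observation slots plus 4 rewind placeholders, for 13 total):

```python
import numpy as np

# (B, T, C) stand-in for a batch of per-frame features; T = 13 slots
features = np.ones((2, 13, 4), dtype=np.float32)
lengths = [9, 13]  # sample 0 has no rewind frames, sample 1 uses all slots

for b_idx, valid_len in enumerate(lengths):
    if valid_len < features.shape[1]:
        features[b_idx, valid_len:] = 0  # Zero out frames beyond the valid length
```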
    def _compute_batch_targets(
        self,
        frame_indices: np.ndarray,
        episode_indices: np.ndarray,
        lengths: torch.Tensor,
        rewind_steps: torch.Tensor,
        episodes_df: pd.DataFrame | None,
        annotation_type: str,
    ) -> torch.Tensor:
        """Compute stage+tau targets for a batch of samples."""
        batch_size = len(frame_indices)
        n_obs_steps = self.config.n_obs_steps
        max_rewind_steps = self.config.max_rewind_steps
        total_frames = 1 + n_obs_steps + max_rewind_steps
        frame_gap = self.config.frame_gap

        global_names, temporal_props = self._get_annotation_config(annotation_type)
        targets = torch.zeros(batch_size, total_frames, dtype=torch.float32)

        for b_idx in range(batch_size):
            ep_idx = int(episode_indices[b_idx])
            frame_idx = int(frame_indices[b_idx])

            ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
            ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
            ep_length = ep_end - ep_start

            subtask_names, subtask_start_frames, subtask_end_frames = self._load_episode_annotations(
                ep_idx, episodes_df, annotation_type, global_names
            )

            # Compute observation frame indices
            obs_indices, _ = compute_absolute_indices(
                frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=frame_gap
            )
            obs_indices = obs_indices.tolist()

            # Compute targets for observation frames
            for t_idx, abs_idx in enumerate(obs_indices):
                rel_frame = abs_idx - ep_start
                targets[b_idx, t_idx] = find_stage_and_tau(
                    rel_frame,
                    ep_length,
                    subtask_names,
                    subtask_start_frames,
                    subtask_end_frames,
                    global_names,
                    temporal_props,
                    return_combined=True,
                )

            # Compute targets for rewind frames (if any)
            rewind_step = rewind_steps[b_idx].item()
            if rewind_step > 0:
                _, rewind_indices = apply_rewind_augmentation(
                    frame_idx,
                    ep_start,
                    n_obs_steps,
                    max_rewind_steps,
                    frame_gap=frame_gap,
                    rewind_step=rewind_step,
                )

                for r_idx, abs_idx in enumerate(rewind_indices[:rewind_step]):
                    rel_frame = max(0, abs_idx - ep_start)
                    targets[b_idx, n_obs_steps + 1 + r_idx] = find_stage_and_tau(
                        rel_frame,
                        ep_length,
                        subtask_names,
                        subtask_start_frames,
                        subtask_end_frames,
                        global_names,
                        temporal_props,
                        return_combined=True,
                    )

        return targets

    @property
    def training(self) -> bool:
        return getattr(self, "_training_mode", True)

    def train(self, mode: bool = True):
        """Set training mode for augmentation decisions."""
        self._training_mode = mode
        return self

    def eval(self):
        """Set evaluation mode (disable augmentations)."""
        return self.train(False)

    @torch.no_grad()
    def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
        """Encode a batch of images using CLIP.

        Args:
            images: Batched images with shape (B, T, C, H, W)

        Returns:
            Encoded feature vectors with shape (B, T, 512)
        """
        batch_size, seq_length = images.shape[0], images.shape[1]
        images = images.reshape(batch_size * seq_length, *images.shape[2:])

        num_frames = images.shape[0]
        images_list = []
        for i in range(num_frames):
            img = images[i]
            if img.shape[0] in [1, 3]:  # Channel-first (C, H, W)
                img = img.transpose(1, 2, 0)

            # Handle single-channel images
            if img.shape[-1] == 1:
                img = np.repeat(img, 3, axis=-1)

            if img.dtype != np.uint8:
                img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)

            images_list.append(Image.fromarray(img))

        all_embeddings = []
        for i in range(0, num_frames, self.config.clip_batch_size):
            batch_imgs = images_list[i : i + self.config.clip_batch_size]

            inputs = self.clip_processor(images=batch_imgs, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Get image embeddings
            embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()

            # Handle the single-frame case
            if embeddings.dim() == 1:
                embeddings = embeddings.unsqueeze(0)

            all_embeddings.append(embeddings)

        all_embeddings = torch.cat(all_embeddings)  # (B*T, 512)
        all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1)  # (B, T, 512)

        return all_embeddings

    @torch.no_grad()
    def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
        """Encode text using the CLIP text encoder (per SARM paper A.4).

        Args:
            text: Task description text to encode
            batch_size: Batch size to replicate for

        Returns:
            Encoded text features with shape (B, 512)
        """
        inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
        text_embedding = text_embedding.expand(batch_size, -1)

        return text_embedding

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Add encoded features to the observation features."""
        features[PipelineFeatureType.OBSERVATION]["video_features"] = PolicyFeature(
            type=FeatureType.VISUAL, shape=(self.config.num_frames, self.config.image_dim)
        )
        features[PipelineFeatureType.OBSERVATION]["text_features"] = PolicyFeature(
            type=FeatureType.LANGUAGE, shape=(self.config.text_dim,)
        )
        features[PipelineFeatureType.OBSERVATION]["state_features"] = PolicyFeature(
            type=FeatureType.STATE, shape=(self.config.num_frames, self.config.max_state_dim)
        )
        return features


def make_sarm_pre_post_processors(
    config: SARMConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    dataset_meta=None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Create pre-processor and post-processor pipelines for SARM."""
    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=[
                AddBatchDimensionProcessorStep(),
                RenameObservationsProcessorStep(rename_map={}),
                NormalizerProcessorStep(
                    features={**config.input_features, **config.output_features},
                    norm_map=config.normalization_mapping,
                    stats=dataset_stats,
                ),
                SARMEncodingProcessorStep(
                    config=config, dataset_meta=dataset_meta, dataset_stats=dataset_stats
                ),
                DeviceProcessorStep(device=config.device),
            ],
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=[DeviceProcessorStep(device="cpu")],
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
@@ -0,0 +1,295 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812


def find_stage_and_tau(
    current_frame: int,
    episode_length: int,
    subtask_names: list | None,
    subtask_start_frames: list | None,
    subtask_end_frames: list | None,
    global_subtask_names: list,
    temporal_proportions: dict,
    return_combined: bool = False,
) -> tuple[int, float] | float:
    """Find the stage and within-stage progress (tau) for a frame.

    Args:
        current_frame: Frame index relative to episode start
        episode_length: Total frames in the episode
        subtask_names: Subtask names for this episode (None for single_stage)
        subtask_start_frames: Subtask start frames
        subtask_end_frames: Subtask end frames
        global_subtask_names: Global list of all subtask names
        temporal_proportions: Dict of temporal proportions
        return_combined: If True, return stage+tau as a float; else a (stage_idx, tau) tuple

    Returns:
        Float (stage.tau) if return_combined, else a (stage_idx, tau) tuple
    """
    stage_idx, tau = 0, 0.0
    num_stages = len(global_subtask_names)

    # Single-stage mode: linear progress from 0 to 1
    if num_stages == 1:
        tau = min(1.0, max(0.0, current_frame / max(episode_length - 1, 1)))
    elif subtask_names is None:
        pass  # stage_idx=0, tau=0.0
    elif current_frame < subtask_start_frames[0]:
        pass  # Before the first subtask: stage_idx=0, tau=0.0
    elif current_frame > subtask_end_frames[-1]:
        stage_idx, tau = num_stages - 1, 0.999  # After the last subtask
    else:
        # Find which subtask this frame belongs to
        found = False
        for name, start, end in zip(subtask_names, subtask_start_frames, subtask_end_frames, strict=True):
            if start <= current_frame <= end:
                stage_idx = global_subtask_names.index(name) if name in global_subtask_names else 0
                tau = compute_tau(current_frame, start, end)
                found = True
                break
        # Frame between subtasks - use the previous subtask's end state
        if not found:
            for j in range(len(subtask_names) - 1):
                if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]:
                    name = subtask_names[j]
                    stage_idx = global_subtask_names.index(name) if name in global_subtask_names else j
                    tau = 1.0
                    break

    if return_combined:
        # Clamp to avoid overflow at the end
        if stage_idx >= num_stages - 1 and tau >= 1.0:
            return num_stages - 1 + 0.999
        return stage_idx + tau
    return stage_idx, tau


def compute_absolute_indices(
    frame_idx: int,
    ep_start: int,
    ep_end: int,
    n_obs_steps: int,
    frame_gap: int = 30,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute absolute frame indices with clamping for a bidirectional observation sequence.

    Bidirectional sampling centered on the target frame:
    - Before: [-frame_gap * half_steps, ..., -frame_gap] (half_steps frames)
    - Current: [0] (1 frame)
    - After: [frame_gap, ..., frame_gap * half_steps] (half_steps frames)
    - Total: n_obs_steps + 1 frames

    Out-of-bounds frames are clamped (duplicated from the boundary).

    Args:
        frame_idx: Target frame index (center frame of the sequence)
        ep_start: Episode start index
        ep_end: Episode end index (exclusive)
        n_obs_steps: Number of observation steps (must be even for symmetric sampling)
        frame_gap: Gap between observation frames

    Returns:
        Tuple of (indices, out_of_bounds_flags)
    """
    half_steps = n_obs_steps // 2

    # Bidirectional deltas: past + current + future
    past_deltas = [-frame_gap * i for i in range(half_steps, 0, -1)]
    future_deltas = [frame_gap * i for i in range(1, half_steps + 1)]
    delta_indices = past_deltas + [0] + future_deltas

    frames = []
    out_of_bounds = []

    for delta in delta_indices:
        target_idx = frame_idx + delta
        # Clamp to episode bounds (duplicate boundary frames for out-of-bounds)
        clamped_idx = max(ep_start, min(ep_end - 1, target_idx))
        frames.append(clamped_idx)
        # Flag as out-of-bounds if clamping occurred
        out_of_bounds.append(1 if target_idx != clamped_idx else 0)

    return torch.tensor(frames), torch.tensor(out_of_bounds)
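A list-based sketch of the clamped bidirectional sampling (same logic, without torch, with worked numbers for an 8-step window):

```python
def bidirectional_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap=30):
    """Indices of n_obs_steps + 1 frames centered on frame_idx, clamped to [ep_start, ep_end)."""
    half = n_obs_steps // 2
    deltas = (
        [-frame_gap * i for i in range(half, 0, -1)] + [0] + [frame_gap * i for i in range(1, half + 1)]
    )
    frames, oob = [], []
    for d in deltas:
        target = frame_idx + d
        clamped = max(ep_start, min(ep_end - 1, target))
        frames.append(clamped)
        oob.append(1 if target != clamped else 0)
    return frames, oob

# 8 obs steps around frame 100 of a 200-frame episode: both ends get clamped
frames, oob = bidirectional_indices(100, 0, 200, 8)
# frames: [0, 10, 40, 70, 100, 130, 160, 190, 199], oob flags mark the two clamped ends
```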


def apply_rewind_augmentation(
    frame_idx: int,
    ep_start: int,
    n_obs_steps: int,
    max_rewind_steps: int,
    frame_gap: int = 30,
    rewind_step: int | None = None,
) -> tuple[int, list[int]]:
    """
    Generate rewind frame indices for temporal augmentation.

    Rewind simulates going backwards through previously seen frames,
    starting from before the earliest observation frame (for bidirectional sampling).
    Appends the reversed frames after the observation sequence.

    Args:
        frame_idx: Target frame index (center of the bidirectional observation window)
        ep_start: Episode start index
        n_obs_steps: Number of observation steps
        max_rewind_steps: Maximum rewind steps
        frame_gap: Gap between frames
        rewind_step: If provided, use this exact rewind step (for deterministic behavior).
            If None, sample randomly.

    Returns:
        Tuple of (rewind_step, rewind_indices)
    """
    # For bidirectional sampling, the earliest obs frame is at frame_idx - half_steps * frame_gap
    half_steps = n_obs_steps // 2
    earliest_obs_frame = frame_idx - half_steps * frame_gap

    # Required history: frames before the earliest observation frame
    if earliest_obs_frame <= ep_start:
        return 0, []  # No history before the observation window

    # Max valid rewind steps based on available history before the earliest obs frame
    available_history = earliest_obs_frame - ep_start
    max_valid_step = available_history // frame_gap
    max_rewind = min(max_rewind_steps, max(0, max_valid_step))

    if max_rewind <= 0:
        return 0, []

    # Sample rewind steps if not provided
    rewind_step = random.randint(1, max_rewind) if rewind_step is None else min(rewind_step, max_rewind)

    if rewind_step == 0:
        return 0, []

    # Generate rewind indices going backwards from the earliest obs frame;
    # rewind_indices[0] is closest to the obs window, rewind_indices[-1] is furthest back
    rewind_indices = []
    for i in range(1, rewind_step + 1):
        idx = earliest_obs_frame - i * frame_gap
        idx = max(ep_start, idx)  # Clamp to episode start
        rewind_indices.append(idx)

    return rewind_step, rewind_indices
|
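A worked example of the rewind index generation above, with the random sampling removed (a standalone sketch; the helper name is illustrative):

```python
def rewind_indices(frame_idx, ep_start, n_obs_steps, frame_gap, rewind_step):
    # Earliest observation frame of the bidirectional window
    half_steps = n_obs_steps // 2
    earliest = frame_idx - half_steps * frame_gap
    # Cap the requested rewind by the history available before that frame
    max_valid = (earliest - ep_start) // frame_gap
    step = min(rewind_step, max_valid)
    # Walk backwards from the earliest obs frame, clamping to the episode start
    return step, [max(ep_start, earliest - i * frame_gap) for i in range(1, step + 1)]


step, idxs = rewind_indices(frame_idx=120, ep_start=0, n_obs_steps=2, frame_gap=30, rewind_step=2)
# step = 2, idxs = [60, 30]: walking back from the earliest obs frame at 90
```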
def compute_tau(current_frame: int | float, subtask_start: int | float, subtask_end: int | float) -> float:
    """Compute τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]. Returns 1.0 for zero-duration subtasks."""
    duration = subtask_end - subtask_start
    if duration <= 0:
        return 1.0
    return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))

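The τ formula above can be exercised directly (a self-contained copy for illustration):

```python
import numpy as np


def compute_tau(current_frame, subtask_start, subtask_end):
    # τ_t = (t - s_k) / (e_k - s_k), clipped to [0, 1]
    duration = subtask_end - subtask_start
    if duration <= 0:
        return 1.0
    return float(np.clip((current_frame - subtask_start) / duration, 0.0, 1.0))


assert compute_tau(75, 50, 100) == 0.5   # halfway through the subtask
assert compute_tau(120, 50, 100) == 1.0  # past the end: clipped
assert compute_tau(10, 40, 40) == 1.0    # zero-duration subtask
```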
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
    """Pad the state tensor's last dimension to max_state_dim with zeros."""
    current_dim = state.shape[-1]
    if current_dim >= max_state_dim:
        return state[..., :max_state_dim]  # Truncate if larger

    # Pad with zeros on the right
    padding = (0, max_state_dim - current_dim)  # (left, right) for last dim
    return F.pad(state, padding, mode="constant", value=0)

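The `(left, right)` pad tuple applies to the last dimension only, which is easy to verify in isolation (hypothetical 7-DoF state padded to a 32-dim slot):

```python
import torch
import torch.nn.functional as F

state = torch.ones(2, 7)                       # e.g. a 7-DoF robot state, batch of 2
padded = F.pad(state, (0, 32 - 7), mode="constant", value=0)
assert padded.shape == (2, 32)
assert padded[:, 7:].abs().sum() == 0          # zeros appended on the right only
```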
def temporal_proportions_to_breakpoints(
    temporal_proportions: dict[str, float] | list[float] | None,
    subtask_names: list[str] | None = None,
) -> list[float] | None:
    """Convert temporal proportions to cumulative breakpoints for normalization."""
    if temporal_proportions is None:
        return None

    if isinstance(temporal_proportions, dict):
        if subtask_names is not None:
            proportions = [temporal_proportions.get(name, 0.0) for name in subtask_names]
        else:
            proportions = list(temporal_proportions.values())
    else:
        proportions = list(temporal_proportions)

    total = sum(proportions)
    if total > 0 and abs(total - 1.0) > 1e-6:
        proportions = [p / total for p in proportions]

    breakpoints = [0.0]
    cumsum = 0.0
    for prop in proportions:
        cumsum += prop
        breakpoints.append(cumsum)
    breakpoints[-1] = 1.0

    return breakpoints

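The cumulative-sum construction above can be sketched standalone (illustrative names; the real function additionally handles dicts and `None`):

```python
def proportions_to_breakpoints(proportions):
    # Renormalize so the proportions sum to 1
    total = sum(proportions)
    proportions = [p / total for p in proportions] if total > 0 else proportions
    # Cumulative sum gives the stage boundaries in normalized progress space
    breakpoints, cumsum = [0.0], 0.0
    for p in proportions:
        cumsum += p
        breakpoints.append(cumsum)
    breakpoints[-1] = 1.0  # guard against floating-point drift
    return breakpoints


# Three subtasks taking 25%, 25% and 50% of an episode:
bps = proportions_to_breakpoints([0.25, 0.25, 0.5])
# bps = [0.0, 0.25, 0.5, 1.0]
```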
def normalize_stage_tau(
    x: float | torch.Tensor,
    num_stages: int | None = None,
    breakpoints: list[float] | None = None,
    temporal_proportions: dict[str, float] | list[float] | None = None,
    subtask_names: list[str] | None = None,
) -> float | torch.Tensor:
    """
    Normalize stage+tau reward to [0, 1] with custom breakpoints.

    Maps stage index + within-stage tau to normalized progress [0, 1].
    The breakpoints are designed to give appropriate weight to each stage
    based on their importance in the task (using temporal proportions).

    Priority: breakpoints > temporal_proportions > linear fallback

    Args:
        x: Raw reward value (stage index + tau) where stage ∈ [0, num_stages-1] and tau ∈ [0, 1)
        num_stages: Number of stages (required if breakpoints/proportions not provided)
        breakpoints: Optional custom breakpoints list of length num_stages + 1.
        temporal_proportions: Optional temporal proportions dict/list to compute breakpoints.
        subtask_names: Optional ordered list of subtask names (for dict proportions)

    Returns:
        Normalized progress value ∈ [0, 1]
    """
    if breakpoints is not None:
        num_stages = len(breakpoints) - 1
    elif temporal_proportions is not None:
        breakpoints = temporal_proportions_to_breakpoints(temporal_proportions, subtask_names)
        num_stages = len(breakpoints) - 1
    elif num_stages is not None:
        breakpoints = [i / num_stages for i in range(num_stages + 1)]
    else:
        raise ValueError("Either num_stages, breakpoints, or temporal_proportions must be provided")

    if isinstance(x, torch.Tensor):
        result = torch.zeros_like(x)
        for i in range(num_stages):
            mask = (x >= i) & (x < i + 1)
            tau_in_stage = x - i
            result[mask] = breakpoints[i] + tau_in_stage[mask] * (breakpoints[i + 1] - breakpoints[i])
        result[x >= num_stages] = 1.0
        return result.clamp(0.0, 1.0)
    else:
        if x < 0:
            return 0.0
        if x >= num_stages:
            return 1.0
        stage = int(x)
        tau = x - stage
        return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])

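A worked example of the scalar interpolation path above (a standalone sketch restricted to the breakpoints case):

```python
def normalize(x, breakpoints):
    # x = stage index + within-stage tau; interpolate inside the stage's breakpoint interval
    num_stages = len(breakpoints) - 1
    if x < 0:
        return 0.0
    if x >= num_stages:
        return 1.0
    stage = int(x)
    tau = x - stage
    return breakpoints[stage] + tau * (breakpoints[stage + 1] - breakpoints[stage])


bps = [0.0, 0.25, 0.5, 1.0]                # stages covering 25%, 25%, 50% of the task
assert normalize(2.5, bps) == 0.75         # halfway through the last (longest) stage
assert normalize(0.5, bps) == 0.125        # halfway through the first stage
assert normalize(3.2, bps) == 1.0          # past the final stage: saturated
```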
@@ -231,6 +231,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: SmolVLAConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -352,8 +353,19 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def _rtc_enabled(self) -> bool:
         return self.config.rtc_config is not None and self.config.rtc_config.enabled

-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
-        """Do a full training forward pass to compute the loss"""
+    def forward(
+        self, batch: dict[str, Tensor], noise=None, time=None, reduction: str = "mean"
+    ) -> dict[str, Tensor]:
+        """Do a full training forward pass to compute the loss.
+
+        Args:
+            batch: Training batch containing observations and actions.
+            noise: Optional noise tensor for flow matching.
+            time: Optional time tensor for flow matching.
+            reduction: How to reduce the loss. Options:
+                - "mean": Return scalar mean loss (default, backward compatible)
+                - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting
+        """
         if self.config.adapt_to_pi_aloha:
             batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
             batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
@@ -377,11 +389,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
         losses = losses[:, :, : self.config.max_action_dim]
         loss_dict["losses_after_rm_padding"] = losses.clone()

-        # For backward pass
-        loss = losses.mean()
-        loss_dict["loss"] = loss.item()
-        return loss, loss_dict
+        if reduction == "none":
+            # Return per-sample losses (B,) by averaging over time and action dims
+            per_sample_loss = losses.mean(dim=(1, 2))
+            loss_dict["loss"] = per_sample_loss.mean().item()
+            return per_sample_loss, loss_dict
+        else:
+            # Default: return scalar mean loss
+            loss = losses.mean()
+            loss_dict["loss"] = loss.item()
+            return loss, loss_dict

     def prepare_images(self, batch):
         """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
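The `reduction="none"` path averages each sample's loss over its time and action dimensions; a quick check that this preserves the overall mean (illustrative shapes):

```python
import torch

losses = torch.rand(4, 50, 32)            # (batch, chunk, action_dim)
per_sample = losses.mean(dim=(1, 2))      # reduction="none" path: one loss per sample
assert per_sample.shape == (4,)
# With equal element counts per sample, the mean of per-sample means equals the global mean
assert torch.isclose(per_sample.mean(), losses.mean())
```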
@@ -65,6 +65,7 @@ class TDMPCPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: TDMPCConfig,
+        **kwargs,
     ):
         """
         Args:
@@ -231,11 +231,20 @@ def validate_visual_features_consistency(
     """
     Validates visual feature consistency between a policy config and provided dataset/environment features.
+
+    Validation passes if EITHER:
+    - Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
+    - Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
+
     Args:
         cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
         features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
     """
     expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
     provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
-    if not provided_visuals.issubset(expected_visuals):
+
+    # Accept if either direction is a subset
+    policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
+    dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
+
+    if not (policy_subset_of_dataset or dataset_subset_of_policy):
         raise_feature_mismatch_error(provided_visuals, expected_visuals)
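The relaxed either-direction subset check can be illustrated with hypothetical camera keys:

```python
# Passes: the policy's cameras are a subset of what the dataset provides
expected = {"observation.images.top", "observation.images.wrist"}
provided = {"observation.images.top", "observation.images.wrist", "observation.images.side"}
assert expected.issubset(provided) or provided.issubset(expected)

# Fails: each side has a camera the other lacks, so a mismatch error would be raised
expected2 = {"observation.images.top", "observation.images.front"}
provided2 = {"observation.images.top", "observation.images.side"}
assert not (expected2.issubset(provided2) or provided2.issubset(expected2))
```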
@@ -47,6 +47,7 @@ class VQBeTPolicy(PreTrainedPolicy):
     def __init__(
         self,
         config: VQBeTConfig | None = None,
+        **kwargs,
     ):
         """
         Args:
@@ -273,7 +273,7 @@ class XVLAPolicy(PreTrainedPolicy):
     config_class = XVLAConfig
     name = "xvla"

-    def __init__(self, config: XVLAConfig):
+    def __init__(self, config: XVLAConfig, **kwargs):
         super().__init__(config)
         config.validate_features()
         florence_config = config.get_florence_config()
@@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
     task_key = {"task": batch["task"]} if "task" in batch else {}
     index_key = {"index": batch["index"]} if "index" in batch else {}
     task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
+    episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}

-    return {**pad_keys, **task_key, **index_key, **task_index_key}
+    return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}


 def create_transition(
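The conditional-key merge pattern used above only propagates keys that are present in the batch (illustrative values):

```python
batch = {"task": "pick cube", "index": 7, "episode_index": 2}

# Each optional key becomes either a one-entry dict or an empty dict...
task_key = {"task": batch["task"]} if "task" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}

# ...so the merged dict silently omits anything the batch lacks
merged = {**task_key, **task_index_key, **episode_index_key}
assert merged == {"task": "pick cube", "episode_index": 2}
```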
@@ -62,6 +62,7 @@ def update_policy(
     accelerator: Accelerator,
     lr_scheduler=None,
     lock=None,
+    rabc_weights_provider=None,
 ) -> tuple[MetricsTracker, dict]:
     """
     Performs a single training step to update the policy's weights.
@@ -78,6 +79,7 @@ def update_policy(
         accelerator: The Accelerator instance for distributed training and mixed precision.
         lr_scheduler: An optional learning rate scheduler.
         lock: An optional lock for thread-safe optimizer updates.
+        rabc_weights_provider: Optional RABCWeights instance for sample weighting.

     Returns:
         A tuple containing:
@@ -87,9 +89,30 @@ def update_policy(
     start_time = time.perf_counter()
     policy.train()

+    # Get RA-BC weights if enabled
+    rabc_batch_weights = None
+    rabc_batch_stats = None
+    if rabc_weights_provider is not None:
+        rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
+
     # Let accelerator handle mixed precision
     with accelerator.autocast():
-        loss, output_dict = policy.forward(batch)
+        # Use per-sample loss when RA-BC is enabled for proper weighting
+        if rabc_batch_weights is not None:
+            # Get per-sample losses
+            per_sample_loss, output_dict = policy.forward(batch, reduction="none")
+
+            # Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
+            # rabc_batch_weights is already normalized to sum to batch_size
+            epsilon = 1e-6
+            loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
+            # Log raw mean weight (before normalization) - this is the meaningful metric
+            output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
+            output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
+            output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
+        else:
+            loss, output_dict = policy.forward(batch)

     # TODO(rcadene): policy.unnormalize_outputs(out_dict)

     # Use accelerator's backward method
@@ -141,8 +164,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         cfg: A `TrainPipelineConfig` object containing all training configurations.
         accelerator: Optional Accelerator instance. If None, one will be created automatically.
     """
-    cfg.validate()
-
     # Create Accelerator if not provided
     # It will automatically detect if running in distributed mode or single-process mode
     # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
@@ -159,6 +180,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     # When using accelerate, only the main process should log to avoid duplicate outputs
     is_main_process = accelerator.is_main_process

+    cfg.validate()
+
     # Only log on main process
     if is_main_process:
         logging.info(pformat(cfg.to_dict()))
@@ -217,6 +240,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
         # Only provide dataset_stats when not resuming from saved processor state
         processor_kwargs["dataset_stats"] = dataset.meta.stats

+    # For SARM, always provide dataset_meta for progress normalization
+    if cfg.policy.type == "sarm":
+        processor_kwargs["dataset_meta"] = dataset.meta
+
     if cfg.policy.pretrained_path is not None:
         processor_kwargs["preprocessor_overrides"] = {
             "device_processor": {"device": device.type},
@@ -248,6 +275,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

+    # Load precomputed SARM progress for RA-BC if enabled
+    # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
+    rabc_weights = None
+    if cfg.use_rabc:
+        from lerobot.utils.rabc import RABCWeights
+
+        # Get chunk_size from policy config
+        chunk_size = getattr(policy.config, "chunk_size", None)
+        if chunk_size is None:
+            raise ValueError("Chunk size is not found in policy config")
+
+        head_mode = getattr(cfg, "rabc_head_mode", "sparse")
+        logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
+        logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
+        rabc_weights = RABCWeights(
+            progress_path=cfg.rabc_progress_path,
+            chunk_size=chunk_size,
+            head_mode=head_mode,
+            kappa=getattr(cfg, "rabc_kappa", 0.01),
+            epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
+            device=device,
+        )
+
     step = 0  # number of policy updates (forward + backward + optim)

     if cfg.resume:
@@ -327,7 +377,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     )

     if is_main_process:
-        logging.info("Start offline training on a fixed dataset")
+        logging.info(
+            f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
+        )

     for _ in range(step, cfg.steps):
         start_time = time.perf_counter()
@@ -343,6 +395,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
             cfg.optimizer.grad_clip_norm,
             accelerator=accelerator,
             lr_scheduler=lr_scheduler,
+            rabc_weights_provider=rabc_weights,
         )

         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -359,6 +412,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
                 wandb_log_dict = train_tracker.to_dict()
                 if output_dict:
                     wandb_log_dict.update(output_dict)
+                # Log RA-BC statistics if enabled
+                if rabc_weights is not None:
+                    rabc_stats = rabc_weights.get_stats()
+                    wandb_log_dict.update(
+                        {
+                            "rabc_delta_mean": rabc_stats["delta_mean"],
+                            "rabc_delta_std": rabc_stats["delta_std"],
+                            "rabc_num_frames": rabc_stats["num_frames"],
+                        }
+                    )
                 wandb_logger.log_dict(wandb_log_dict, step)
                 train_tracker.reset_averages()
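The weighted-loss formula in the hunk above, L_RA-BC = Σ(w_i · l_i) / (Σw_i + ε), can be checked numerically (illustrative values; weights already normalized to sum to the batch size):

```python
import torch

per_sample_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
weights = torch.tensor([0.0, 1.0, 1.0, 2.0])   # sums to batch_size = 4
epsilon = 1e-6
loss = (per_sample_loss * weights).sum() / (weights.sum() + epsilon)
# (0 + 2 + 3 + 8) / 4 = 3.25: the zero-weight (negative-progress) sample is excluded
assert torch.isclose(loss, torch.tensor(3.25))
```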
@@ -0,0 +1,276 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class RABCWeights:
|
||||||
|
"""
|
||||||
|
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||||||
|
|
||||||
|
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||||||
|
During training, computes:
|
||||||
|
- progress_delta = progress[t + chunk_size] - progress[t]
|
||||||
|
- rabc_weight based on the delta (paper Eq. 8-9)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
progress_path: Path to parquet file with precomputed progress values
|
||||||
|
chunk_size: Number of frames ahead for computing progress delta
|
||||||
|
head_mode: Which SARM head to use ("sparse" or "dense")
|
||||||
|
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||||||
|
epsilon: Small constant for numerical stability (default: 1e-6)
|
||||||
|
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
|
||||||
|
device: Device to return tensors on
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
progress_path: str | Path,
|
||||||
|
chunk_size: int = 50,
|
||||||
|
head_mode: str = "sparse",
|
||||||
|
kappa: float = 0.01,
|
||||||
|
epsilon: float = 1e-6,
|
||||||
|
fallback_weight: float = 1.0,
|
||||||
|
device: torch.device = None,
|
||||||
|
):
|
||||||
|
self.progress_path = Path(progress_path)
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.head_mode = head_mode
|
||||||
|
self.kappa = kappa
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.fallback_weight = fallback_weight
|
||||||
|
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Determine progress column name
|
||||||
|
self.progress_column = f"progress_{head_mode}"
|
||||||
|
|
||||||
|
# Load progress values
|
||||||
|
logging.info(f"Loading SARM progress values from {self.progress_path}")
|
||||||
|
self.df = pd.read_parquet(self.progress_path)
|
||||||
|
|
||||||
|
# Check if the requested head mode column exists
|
||||||
|
if self.progress_column not in self.df.columns:
|
||||||
|
available = [c for c in self.df.columns if c.startswith("progress")]
|
||||||
|
raise ValueError(
|
||||||
|
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"Using progress column: {self.progress_column}")
|
||||||
|
|
||||||
|
self.progress_lookup = {}
|
||||||
|
self.episode_lookup = {}
|
||||||
|
|
||||||
|
for _, row in self.df.iterrows():
|
||||||
|
global_idx = int(row["index"])
|
||||||
|
progress = row[self.progress_column]
|
||||||
|
episode_idx = int(row["episode_index"])
|
||||||
|
|
||||||
|
if not np.isnan(progress):
|
||||||
|
self.progress_lookup[global_idx] = float(progress)
|
||||||
|
self.episode_lookup[global_idx] = episode_idx
|
||||||
|
|
||||||
|
# Build episode boundaries for delta computation
|
||||||
|
self.episode_boundaries = {}
|
||||||
|
for episode_idx in self.df["episode_index"].unique():
|
||||||
|
ep_df = self.df[self.df["episode_index"] == episode_idx]
|
||||||
|
self.episode_boundaries[int(episode_idx)] = {
|
||||||
|
"start": int(ep_df["index"].min()),
|
||||||
|
"end": int(ep_df["index"].max()) + 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
|
||||||
|
logging.info(f"Chunk size for delta computation: {chunk_size}")
|
||||||
|
|
||||||
|
# Compute global statistics for weight computation
|
||||||
|
self._compute_global_stats()
|
||||||
|
|
||||||
|
def _compute_global_stats(self):
|
||||||
|
"""Compute global mean and std of progress deltas for weight calculation."""
|
||||||
|
all_deltas = []
|
||||||
|
|
||||||
|
for global_idx, progress in self.progress_lookup.items():
|
||||||
|
episode_idx = self.episode_lookup.get(global_idx)
|
||||||
|
if episode_idx is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
bounds = self.episode_boundaries.get(episode_idx)
|
||||||
|
if bounds is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
future_idx = global_idx + self.chunk_size
|
||||||
|
if future_idx >= bounds["end"]:
|
||||||
|
# Near end of episode: use last frame's progress
|
||||||
|
future_idx = bounds["end"] - 1
|
||||||
|
|
||||||
|
future_progress = self.progress_lookup.get(future_idx)
|
||||||
|
if future_progress is not None:
|
||||||
|
delta = future_progress - progress
|
||||||
|
all_deltas.append(delta)
|
||||||
|
|
||||||
|
if all_deltas:
|
||||||
|
self.delta_mean = max(np.mean(all_deltas), 0.0)
|
||||||
|
self.delta_std = max(np.std(all_deltas), self.epsilon)
|
||||||
|
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
|
||||||
|
else:
|
||||||
|
self.delta_mean = 0.0
|
||||||
|
self.delta_std = self.epsilon
|
||||||
|
logging.warning("No valid progress deltas found, using default stats")
|
||||||
|
|
||||||
|
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||||
|
"""
|
||||||
|
Compute RA-BC weights for a batch.
|
||||||
|
|
||||||
|
For each sample:
|
||||||
|
1. Get progress at current frame
|
||||||
|
2. Get progress at frame + chunk_size (within same episode)
|
||||||
|
3. Compute delta = future_progress - current_progress
|
||||||
|
4. Compute weight using paper Eq. 8-9
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Training batch containing "index" key with global frame indices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of:
|
||||||
|
- Weights tensor (batch_size,) normalized to sum to batch_size
|
||||||
|
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
|
||||||
|
"""
|
||||||
|
indices = batch.get("index")
|
||||||
|
if indices is None:
|
||||||
|
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
|
||||||
|
batch_size = self._get_batch_size(batch)
|
||||||
|
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
|
||||||
|
|
||||||
|
# Convert to list of ints
|
||||||
|
if isinstance(indices, torch.Tensor):
|
||||||
|
indices = indices.cpu().numpy().tolist()
|
||||||
|
elif isinstance(indices, np.ndarray):
|
||||||
|
indices = indices.tolist()
|
||||||
|
|
||||||
|
# Compute deltas and weights for each sample
|
||||||
|
deltas = []
|
||||||
|
for idx in indices:
|
||||||
|
idx = int(idx)
|
||||||
|
delta = self._compute_delta(idx)
|
||||||
|
deltas.append(delta)
|
||||||
|
|
||||||
|
deltas = np.array(deltas, dtype=np.float32)
|
||||||
|
|
||||||
|
# Compute weights from deltas
|
||||||
|
weights = self._compute_weights(deltas)
|
||||||
|
|
||||||
|
# Compute stats before normalization for logging
|
||||||
|
raw_mean_weight = float(np.nanmean(weights))
|
||||||
|
num_zero_weight = int(np.sum(weights == 0))
|
||||||
|
num_full_weight = int(np.sum(weights == 1.0))
|
||||||
|
batch_stats = {
|
||||||
|
"raw_mean_weight": raw_mean_weight,
|
||||||
|
"num_zero_weight": num_zero_weight,
|
||||||
|
"num_full_weight": num_full_weight,
|
||||||
|
}
|
||||||
|
|
||||||
|
        weights = torch.tensor(weights, device=self.device, dtype=torch.float32)

        # Normalize to sum to batch_size
        batch_size = len(weights)
        weight_sum = weights.sum() + self.epsilon
        weights = weights * batch_size / weight_sum

        return weights, batch_stats

    def _compute_delta(self, global_idx: int) -> float:
        """Compute progress delta for a single frame."""
        current_progress = self.progress_lookup.get(global_idx)
        if current_progress is None:
            return np.nan

        episode_idx = self.episode_lookup.get(global_idx)
        if episode_idx is None:
            return np.nan

        bounds = self.episode_boundaries.get(episode_idx)
        if bounds is None:
            return np.nan

        future_idx = global_idx + self.chunk_size  # Δ = chunk_size
        if future_idx >= bounds["end"]:
            # Near end of episode: use last frame's progress instead
            future_idx = bounds["end"] - 1

        future_progress = self.progress_lookup.get(future_idx)
        if future_progress is None:
            return np.nan

        return future_progress - current_progress

    def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
        """
        Compute RA-BC weights from progress deltas.

        Following paper Eq. 8-9:
        - Soft weight: w̃_i = clip((r_i − (µ − 2σ)) / (4σ + ε), 0, 1)
        - Final weight: w_i = 1{r_i > κ} + 1{0 ≤ r_i ≤ κ} · w̃_i

        Returns:
            Array of weights
        """
        valid_mask = ~np.isnan(deltas)

        # Compute soft weights using global statistics
        lower_bound = self.delta_mean - 2 * self.delta_std
        soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
        soft_weights = np.clip(soft_weights, 0.0, 1.0)

        # Apply paper's Eq. 9
        weights = np.zeros_like(deltas, dtype=np.float32)

        # High quality: r_i > kappa → weight = 1
        high_quality_mask = deltas > self.kappa
        weights[high_quality_mask] = 1.0

        # Moderate quality: 0 <= r_i <= kappa → weight = soft_weight
        moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
        weights[moderate_mask] = soft_weights[moderate_mask]

        # Negative progress: r_i < 0 → weight = 0 (already 0)
        # Invalid (NaN): use fallback weight
        weights[~valid_mask] = self.fallback_weight

        return weights

    def _get_batch_size(self, batch: dict) -> int:
        """Determine batch size from batch."""
        for key in ["action", "index"]:
            if key in batch:
                val = batch[key]
                if isinstance(val, (torch.Tensor, np.ndarray)):
                    return val.shape[0]
        return 1

    def get_stats(self) -> dict:
        """Get statistics."""
        return {
            "num_frames": len(self.progress_lookup),
            "chunk_size": self.chunk_size,
            "head_mode": self.head_mode,
            "delta_mean": self.delta_mean,
            "delta_std": self.delta_std,
            "kappa": self.kappa,
        }
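The Eq. 8-9 weighting in `_compute_weights` can be sketched as a standalone NumPy function. The statistics below (mean, std, kappa) are hypothetical example values, and this sketch omits the batch-size normalization the class applies afterwards:

```python
import numpy as np

def rabc_weights(deltas, mean, std, kappa, eps=1e-8, fallback=0.5):
    """Standalone sketch of the Eq. 8-9 weighting: a soft weight from
    global delta statistics, hard 1/0 outside the [0, kappa] band."""
    deltas = np.asarray(deltas, dtype=np.float32)
    # Eq. 8: soft weight scaled by the global delta distribution
    soft = np.clip((deltas - (mean - 2 * std)) / (4 * std + eps), 0.0, 1.0)
    weights = np.zeros_like(deltas)
    weights[deltas > kappa] = 1.0                 # high-quality frames
    moderate = (deltas >= 0) & (deltas <= kappa)
    weights[moderate] = soft[moderate]            # soft-weighted band
    weights[np.isnan(deltas)] = fallback          # frames with no valid delta
    return weights

# With mean=0.01, std=0.005, kappa=0.02:
# delta -0.01 -> 0.0, delta 0.01 -> 0.5, delta 0.03 -> 1.0, NaN -> fallback
```

Negative deltas fall through to the zero initialization, matching the hard mask on negative progress in the method above.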
@@ -0,0 +1,694 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

pytest.importorskip("faker")

from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import torch

from lerobot.processor.core import TransitionKey


class MockDatasetMeta:
    """Mock dataset metadata for testing processor."""

    def __init__(self, episodes: list[dict]):
        self._episodes = episodes

    @property
    def episodes(self):
        """Return episodes as a mock object with to_pandas() method."""
        mock = MagicMock()
        mock.__len__ = lambda s: len(self._episodes)
        mock.__getitem__ = lambda s, idx: self._episodes[idx]
        mock.to_pandas = lambda: pd.DataFrame(self._episodes)
        return mock


class MockConfig:
    """Mock SARMConfig for testing processor methods."""

    def __init__(
        self,
        n_obs_steps: int = 8,
        max_rewind_steps: int = 4,
        frame_gap: int = 30,
        sparse_subtask_names: list = None,
        sparse_temporal_proportions: list = None,
        dense_subtask_names: list = None,
        dense_temporal_proportions: list = None,
        image_key: str = "observation.images.top",
        state_key: str = "observation.state",
        max_state_dim: int = 32,
        device: str = None,
        rewind_probability: float = 0.8,
        language_perturbation_probability: float = 0.2,
        annotation_mode: str = "dual",
        clip_batch_size: int = 64,
        text_dim: int = 512,
    ):
        self.n_obs_steps = n_obs_steps
        self.max_rewind_steps = max_rewind_steps
        self.frame_gap = frame_gap
        self.sparse_subtask_names = sparse_subtask_names or ["task"]
        self.sparse_temporal_proportions = sparse_temporal_proportions or [1.0]
        self.dense_subtask_names = dense_subtask_names
        self.dense_temporal_proportions = dense_temporal_proportions
        self.uses_dual_heads = annotation_mode in ["dense_only", "dual"]
        self.image_key = image_key
        self.state_key = state_key
        self.max_state_dim = max_state_dim
        self.device = device
        self.rewind_probability = rewind_probability
        self.language_perturbation_probability = language_perturbation_probability
        self.annotation_mode = annotation_mode
        self.clip_batch_size = clip_batch_size
        self.text_dim = text_dim

        # Compute observation delta indices (same as config: bidirectional)
        half_steps = self.n_obs_steps // 2
        past_deltas = [-self.frame_gap * i for i in range(half_steps, 0, -1)]
        future_deltas = [self.frame_gap * i for i in range(1, half_steps + 1)]
        obs_deltas = past_deltas + [0] + future_deltas
        rewind_deltas = [-self.frame_gap * (i + 1) for i in range(self.max_rewind_steps)]
        self.observation_delta_indices = obs_deltas + rewind_deltas

    @property
    def num_frames(self) -> int:
        return 1 + self.n_obs_steps + self.max_rewind_steps
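The bidirectional window built in `MockConfig.__init__` can be checked by hand. A quick standalone version of the same arithmetic, using the default n_obs_steps=8, frame_gap=30, max_rewind_steps=4:

```python
n_obs_steps, frame_gap, max_rewind_steps = 8, 30, 4

half_steps = n_obs_steps // 2
past = [-frame_gap * i for i in range(half_steps, 0, -1)]    # [-120, -90, -60, -30]
future = [frame_gap * i for i in range(1, half_steps + 1)]   # [30, 60, 90, 120]
obs_deltas = past + [0] + future                             # 9 observation offsets
rewind = [-frame_gap * (i + 1) for i in range(max_rewind_steps)]  # [-30, -60, -90, -120]

observation_delta_indices = obs_deltas + rewind  # 13 offsets = 9 obs + 4 rewind
```

This is why `num_frames` is 1 + n_obs_steps + max_rewind_steps = 13: the current frame plus eight surrounding observation frames plus four rewind candidates.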


class TestSARMEncodingProcessorStepEndToEnd:
    """End-to-end test for SARMEncodingProcessorStep with dummy batch data."""

    @pytest.fixture
    def mock_clip_model(self):
        """Mock CLIP model to avoid loading real weights."""
        with (
            patch("lerobot.policies.sarm.processor_sarm.CLIPModel") as mock_model_cls,
            patch("lerobot.policies.sarm.processor_sarm.CLIPProcessor") as mock_processor_cls,
        ):
            # Mock the CLIP model - return embeddings based on input batch size
            mock_model = MagicMock()

            def get_image_features_side_effect(**kwargs):
                pixel_values = kwargs.get("pixel_values")
                batch_size = pixel_values.shape[0] if pixel_values is not None else 1
                return torch.randn(batch_size, 512)

            mock_model.get_image_features.side_effect = get_image_features_side_effect
            mock_model.get_text_features.return_value = torch.randn(1, 512)
            mock_model.to.return_value = mock_model
            mock_model_cls.from_pretrained.return_value = mock_model

            # Mock the CLIP processor - return tensors based on input images
            mock_processor = MagicMock()

            def processor_side_effect(images=None, **kwargs):
                num_images = len(images) if images is not None else 1
                return {
                    "pixel_values": torch.randn(num_images, 3, 224, 224),
                }

            mock_processor.side_effect = processor_side_effect
            # Mock tokenizer for text encoding
            mock_processor.tokenizer.return_value = {
                "input_ids": torch.ones(1, 77, dtype=torch.long),
                "attention_mask": torch.ones(1, 77, dtype=torch.long),
            }
            mock_processor_cls.from_pretrained.return_value = mock_processor

            yield mock_model, mock_processor

    @pytest.fixture
    def processor_with_mocks(self, mock_clip_model):
        """Create a processor with mocked CLIP and dataset metadata for dual mode."""
        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep

        # Dual mode config with both sparse and dense annotations
        config = MockConfig(
            n_obs_steps=8,
            max_rewind_steps=4,
            frame_gap=30,
            rewind_probability=0.0,  # Disable for deterministic test
            language_perturbation_probability=0.0,  # Disable for deterministic test
            annotation_mode="dual",
            sparse_subtask_names=["reach", "grasp", "lift"],
            sparse_temporal_proportions=[0.3, 0.4, 0.3],
            dense_subtask_names=["approach", "contact", "close_gripper", "lift_up"],
            dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
        )

        # Create mock dataset metadata with one episode of 300 frames
        # Include annotation columns for dual mode
        episodes = [
            {
                "dataset_from_index": 0,
                "dataset_to_index": 300,
                "task": "pick up the cube",
                "sparse_subtask_names": ["reach", "grasp", "lift"],
                "sparse_subtask_start_frames": [0, 90, 210],
                "sparse_subtask_end_frames": [90, 210, 300],
                "dense_subtask_names": ["approach", "contact", "close_gripper", "lift_up"],
                "dense_subtask_start_frames": [0, 75, 150, 225],
                "dense_subtask_end_frames": [75, 150, 225, 300],
            }
        ]
        dataset_meta = MockDatasetMeta(episodes)

        processor = SARMEncodingProcessorStep(
            config=config,
            dataset_meta=dataset_meta,
        )
        processor.train(True)  # Use train() method, not direct assignment

        return processor, config
    def test_call_with_single_frame_batch(self, processor_with_mocks):
        """Test processor __call__ with a single-frame batch."""
        processor, config = processor_with_mocks

        # Create dummy input transition
        batch_size = 1
        num_frames = config.num_frames  # 13 frames (9 obs + 4 rewind)

        # Image: (T, C, H, W) format as expected by processor
        dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)

        # State: (T, D) format
        dummy_state = np.random.rand(num_frames, 6).astype(np.float32)

        transition = {
            TransitionKey.OBSERVATION: {
                config.image_key: dummy_image,
                config.state_key: dummy_state,
            },
            TransitionKey.COMPLEMENTARY_DATA: {
                "index": 150,  # Middle of episode
                "episode_index": 0,
                "task": "pick up the cube",
            },
        }

        # Run processor
        result = processor(transition)

        # Verify output structure
        obs = result[TransitionKey.OBSERVATION]

        # Check video features exist and have correct shape
        assert "video_features" in obs
        video_features = obs["video_features"]
        assert video_features.shape[0] == batch_size
        assert video_features.shape[1] == num_frames
        assert video_features.shape[2] == 512  # CLIP embedding dim

        # Check state features exist and have correct shape
        assert "state_features" in obs
        state_features = obs["state_features"]
        assert state_features.shape[0] == batch_size
        assert state_features.shape[1] == num_frames
        assert state_features.shape[2] == config.max_state_dim  # Padded to max_state_dim

        # Check text features exist and have correct shape
        assert "text_features" in obs
        text_features = obs["text_features"]
        assert text_features.shape[0] == batch_size
        assert text_features.shape[1] == 512  # CLIP embedding dim

        # Check lengths tensor
        assert "lengths" in obs
        lengths = obs["lengths"]
        assert lengths.shape[0] == batch_size
        assert lengths.dtype == torch.int32

        # Check sparse_targets exist
        assert "sparse_targets" in obs
        sparse_targets = obs["sparse_targets"]
        assert sparse_targets.shape == (batch_size, num_frames)
        # All targets should be in [0, max_stages] range (stage.tau format)
        assert (sparse_targets >= 0).all()

        # Check dense_targets exist (for dual mode)
        assert "dense_targets" in obs
        dense_targets = obs["dense_targets"]
        assert dense_targets.shape == (batch_size, num_frames)
        assert (dense_targets >= 0).all()
    def test_call_with_batched_input(self, mock_clip_model):
        """Test processor __call__ with a batched input (multiple frames) in dual mode."""
        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep

        config = MockConfig(
            n_obs_steps=8,
            max_rewind_steps=4,
            frame_gap=30,
            rewind_probability=0.0,
            language_perturbation_probability=0.0,
            annotation_mode="dual",
            sparse_subtask_names=["reach", "grasp"],
            sparse_temporal_proportions=[0.5, 0.5],
            dense_subtask_names=["step1", "step2", "step3"],
            dense_temporal_proportions=[0.33, 0.34, 0.33],
        )

        # Two episodes with different lengths, each with sparse+dense annotations
        episodes = [
            {
                "dataset_from_index": 0,
                "dataset_to_index": 200,
                "task": "task A",
                "sparse_subtask_names": ["reach", "grasp"],
                "sparse_subtask_start_frames": [0, 100],
                "sparse_subtask_end_frames": [100, 200],
                "dense_subtask_names": ["step1", "step2", "step3"],
                "dense_subtask_start_frames": [0, 66, 133],
                "dense_subtask_end_frames": [66, 133, 200],
            },
            {
                "dataset_from_index": 200,
                "dataset_to_index": 500,
                "task": "task B",
                "sparse_subtask_names": ["reach", "grasp"],
                "sparse_subtask_start_frames": [200, 350],
                "sparse_subtask_end_frames": [350, 500],
                "dense_subtask_names": ["step1", "step2", "step3"],
                "dense_subtask_start_frames": [200, 300, 400],
                "dense_subtask_end_frames": [300, 400, 500],
            },
        ]
        dataset_meta = MockDatasetMeta(episodes)

        processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
        processor.train(True)

        batch_size = 2
        num_frames = config.num_frames

        # Image: (B, T, C, H, W) format
        dummy_image = np.random.rand(batch_size, num_frames, 3, 224, 224).astype(np.float32)
        dummy_state = np.random.rand(batch_size, num_frames, 6).astype(np.float32)

        transition = {
            TransitionKey.OBSERVATION: {
                config.image_key: dummy_image,
                config.state_key: dummy_state,
            },
            TransitionKey.COMPLEMENTARY_DATA: {
                "index": np.array([100, 350]),  # One frame from each episode
                "episode_index": np.array([0, 1]),
                "task": ["task A", "task B"],
            },
        }

        result = processor(transition)
        obs = result[TransitionKey.OBSERVATION]

        # Verify batch dimension is preserved for all outputs
        assert obs["video_features"].shape[0] == batch_size
        assert obs["state_features"].shape[0] == batch_size
        assert obs["lengths"].shape[0] == batch_size
        assert obs["sparse_targets"].shape[0] == batch_size
        assert obs["dense_targets"].shape[0] == batch_size  # Dual mode has dense targets
    def test_targets_increase_with_progress(self, mock_clip_model):
        """Test that both sparse and dense targets increase as frame index progresses."""
        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep

        config = MockConfig(
            n_obs_steps=8,
            max_rewind_steps=4,
            frame_gap=30,
            rewind_probability=0.0,
            language_perturbation_probability=0.0,
            annotation_mode="dual",
            sparse_subtask_names=["phase1", "phase2"],
            sparse_temporal_proportions=[0.5, 0.5],
            dense_subtask_names=["a", "b", "c", "d"],
            dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
        )

        episodes = [
            {
                "dataset_from_index": 0,
                "dataset_to_index": 300,
                "task": "test task",
                "sparse_subtask_names": ["phase1", "phase2"],
                "sparse_subtask_start_frames": [0, 150],
                "sparse_subtask_end_frames": [150, 300],
                "dense_subtask_names": ["a", "b", "c", "d"],
                "dense_subtask_start_frames": [0, 75, 150, 225],
                "dense_subtask_end_frames": [75, 150, 225, 300],
            }
        ]
        dataset_meta = MockDatasetMeta(episodes)

        processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
        processor.train(True)

        num_frames = config.num_frames

        # Test at early, middle, and late points in episode
        frame_indices = [30, 150, 270]
        sparse_center_targets = []
        dense_center_targets = []

        for frame_idx in frame_indices:
            dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
            dummy_state = np.random.rand(num_frames, 6).astype(np.float32)

            transition = {
                TransitionKey.OBSERVATION: {
                    config.image_key: dummy_image,
                    config.state_key: dummy_state,
                },
                TransitionKey.COMPLEMENTARY_DATA: {
                    "index": frame_idx,
                    "episode_index": 0,
                    "task": "test task",
                },
            }

            result = processor(transition)
            obs = result[TransitionKey.OBSERVATION]
            # Get target at center frame (index 4 in 9-frame observation window)
            sparse_center_targets.append(obs["sparse_targets"][0, 4].item())
            dense_center_targets.append(obs["dense_targets"][0, 4].item())

        # Both sparse and dense targets should increase with frame index
        assert sparse_center_targets[0] < sparse_center_targets[2], (
            f"Early sparse target ({sparse_center_targets[0]}) should be < late ({sparse_center_targets[2]})"
        )
        assert dense_center_targets[0] < dense_center_targets[2], (
            f"Early dense target ({dense_center_targets[0]}) should be < late ({dense_center_targets[2]})"
        )
    def test_progress_labels_exact_values(self, mock_clip_model):
        """Test that progress labels (stage.tau) are computed correctly for known positions."""
        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep

        # Simple setup: 2 sparse stages, 4 dense stages, 100 frame episode
        config = MockConfig(
            n_obs_steps=8,
            max_rewind_steps=4,
            frame_gap=10,  # Smaller gap for easier calculation
            rewind_probability=0.0,
            language_perturbation_probability=0.0,
            annotation_mode="dual",
            sparse_subtask_names=["A", "B"],
            sparse_temporal_proportions=[0.5, 0.5],
            dense_subtask_names=["d1", "d2", "d3", "d4"],
            dense_temporal_proportions=[0.25, 0.25, 0.25, 0.25],
        )

        # Episode: frames 0-99, sparse stages at [0-49], [50-99]
        # Dense stages at [0-24], [25-49], [50-74], [75-99]
        episodes = [
            {
                "dataset_from_index": 0,
                "dataset_to_index": 100,
                "task": "test",
                "sparse_subtask_names": ["A", "B"],
                "sparse_subtask_start_frames": [0, 50],
                "sparse_subtask_end_frames": [50, 100],
                "dense_subtask_names": ["d1", "d2", "d3", "d4"],
                "dense_subtask_start_frames": [0, 25, 50, 75],
                "dense_subtask_end_frames": [25, 50, 75, 100],
            }
        ]
        dataset_meta = MockDatasetMeta(episodes)

        processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
        processor.train(True)

        num_frames = config.num_frames

        # Test at frame 50 (center of episode)
        # With frame_gap=10, n_obs_steps=8:
        # obs indices around frame 50: [10, 20, 30, 40, 50, 60, 70, 80, 90] (9 frames)
        dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
        dummy_state = np.random.rand(num_frames, 6).astype(np.float32)

        transition = {
            TransitionKey.OBSERVATION: {
                config.image_key: dummy_image,
                config.state_key: dummy_state,
            },
            TransitionKey.COMPLEMENTARY_DATA: {
                "index": 50,
                "episode_index": 0,
                "task": "test",
            },
        }

        result = processor(transition)
        obs = result[TransitionKey.OBSERVATION]
        sparse_targets = obs["sparse_targets"][0]  # (13,)
        dense_targets = obs["dense_targets"][0]  # (13,)

        # First 9 frames are observation frames, last 4 are rewind placeholders (zeros when no rewind)
        # Check that obs frames have non-zero targets
        obs_sparse = sparse_targets[:9]
        obs_dense = dense_targets[:9]

        # Verify targets are monotonically increasing for observation frames
        for i in range(1, 9):
            assert obs_sparse[i] >= obs_sparse[i - 1], (
                f"Sparse targets should be monotonic: {obs_sparse[i - 1].item():.3f} -> {obs_sparse[i].item():.3f}"
            )
            assert obs_dense[i] >= obs_dense[i - 1], (
                f"Dense targets should be monotonic: {obs_dense[i - 1].item():.3f} -> {obs_dense[i].item():.3f}"
            )

        # Rewind slots should be zero when rewind is disabled
        rewind_targets = sparse_targets[9:]
        assert (rewind_targets == 0).all(), "Rewind slots should be zero when rewind is disabled"

        # Check stage transitions: frame 50 is at boundary of sparse stage A->B
        # Center frame (index 4) corresponds to actual frame 50
        center_sparse = obs_sparse[4].item()
        # At frame 50, sparse stage B starts, so target should be ~1.0 (stage 1 + tau 0)
        assert 0.9 <= center_sparse <= 1.1, (
            f"At sparse boundary, target should be ~1.0, got {center_sparse:.3f}"
        )
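The `stage.tau` label format asserted above can be illustrated in isolation: the integer part is the stage index, the fractional part the linear progress within that stage. This toy version ignores the temporal-proportion weighting the real `find_stage_and_tau` applies, so treat it as illustrative only:

```python
def stage_tau(frame: int, starts: list[int], ends: list[int]) -> float:
    """Toy stage.tau label: stage index plus linear progress within the stage."""
    for k, (s, e) in enumerate(zip(starts, ends)):
        if s <= frame < e:
            return k + (frame - s) / (e - s)
    return float(len(starts))  # past the final stage

# Sparse stages [0, 50) and [50, 100): frame 50 is exactly the A->B boundary,
# so its label is 1.0 (stage 1, tau 0) — the value the test above checks for.
```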
    def test_rewind_augmentation_applied(self, mock_clip_model):
        """Test that rewind augmentation correctly extends sequence and generates targets."""
        import random

        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep

        config = MockConfig(
            n_obs_steps=8,
            max_rewind_steps=4,
            frame_gap=10,
            rewind_probability=1.0,  # Always apply rewind
            language_perturbation_probability=0.0,
            annotation_mode="dual",
            sparse_subtask_names=["A", "B"],
            sparse_temporal_proportions=[0.5, 0.5],
            dense_subtask_names=["d1", "d2"],
            dense_temporal_proportions=[0.5, 0.5],
        )

        episodes = [
            {
                "dataset_from_index": 0,
                "dataset_to_index": 200,
                "task": "test",
                "sparse_subtask_names": ["A", "B"],
                "sparse_subtask_start_frames": [0, 100],
                "sparse_subtask_end_frames": [100, 200],
                "dense_subtask_names": ["d1", "d2"],
                "dense_subtask_start_frames": [0, 100],
                "dense_subtask_end_frames": [100, 200],
            }
        ]
        dataset_meta = MockDatasetMeta(episodes)

        processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
        processor.train(True)

        num_frames = config.num_frames  # 13

        # Test at frame 150 (center of bidirectional window)
        # With n_obs_steps=8, half_steps=4, frame_gap=10:
        # - Earliest obs frame = 150 - 4*10 = 110
        # - Rewind can go back from 110 to frames like 100, 90, 80, 70
        # - History available = 110 - 0 = 110, so max rewind = 110/10 = 11 (capped at 4)
        dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
        dummy_state = np.random.rand(num_frames, 6).astype(np.float32)

        transition = {
            TransitionKey.OBSERVATION: {
                config.image_key: dummy_image,
                config.state_key: dummy_state,
            },
            TransitionKey.COMPLEMENTARY_DATA: {
                "index": 150,
                "episode_index": 0,
                "task": "test",
            },
        }

        # Seed random for reproducibility
        random.seed(42)
        result = processor(transition)
        obs = result[TransitionKey.OBSERVATION]

        lengths = obs["lengths"][0].item()
        sparse_targets = obs["sparse_targets"][0]

        # With rewind_probability=1.0 and enough history, lengths should be > 9 (9 obs + some rewind)
        assert lengths > 9, f"With rewind enabled, lengths should be > 9, got {lengths}"
        assert lengths <= num_frames, f"Lengths should not exceed total frames {num_frames}, got {lengths}"

        # Rewind targets should be non-zero for frames within valid length
        n_obs_frames = 9
        rewind_count = lengths - n_obs_frames

        if rewind_count > 0:
            # Check that rewind frames have targets
            rewind_targets = sparse_targets[n_obs_frames : n_obs_frames + rewind_count]
            # Rewind frames are from BEFORE the earliest obs frame (110)
            # These frames (100, 90, 80, 70) are earlier in the episode
            earliest_obs_target = sparse_targets[0].item()  # Frame 110

            # Rewind targets should be less than earliest obs (they're from earlier frames)
            for i, rt in enumerate(rewind_targets):
                assert rt.item() < earliest_obs_target, (
                    f"Rewind target {i} ({rt.item():.3f}) should be < earliest obs ({earliest_obs_target:.3f})"
                )

            # Rewind targets should be decreasing (going further back in time)
            for i in range(1, len(rewind_targets)):
                assert rewind_targets[i] <= rewind_targets[i - 1], (
                    f"Rewind targets should decrease: {rewind_targets[i - 1].item():.3f} -> {rewind_targets[i].item():.3f}"
                )
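The rewind-window arithmetic walked through in the comments above can be checked standalone. Using the values from this test (center frame 150, frame_gap=10, n_obs_steps=8, max_rewind_steps=4, episode starting at frame 0):

```python
frame_gap, n_obs_steps, max_rewind_steps = 10, 8, 4
center, episode_start = 150, 0

half_steps = n_obs_steps // 2
earliest_obs = center - half_steps * frame_gap            # 150 - 40 = 110
history = earliest_obs - episode_start                    # 110 frames of history
max_rewind = min(history // frame_gap, max_rewind_steps)  # min(11, 4) = 4
rewind_frames = [earliest_obs - frame_gap * (i + 1) for i in range(max_rewind)]
# rewind_frames == [100, 90, 80, 70]
```

Since each rewind frame precedes the earliest observation frame, its progress label must be smaller, which is exactly the monotonicity the assertions above verify.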
    def test_full_sequence_target_consistency(self, mock_clip_model):
        """Test that the full sequence of targets is consistent with frame positions."""
        from lerobot.policies.sarm.processor_sarm import SARMEncodingProcessorStep
        from lerobot.policies.sarm.sarm_utils import find_stage_and_tau

        config = MockConfig(
            n_obs_steps=8,
            max_rewind_steps=4,
            frame_gap=10,
            rewind_probability=0.0,
            language_perturbation_probability=0.0,
            annotation_mode="dual",
            sparse_subtask_names=["s1", "s2", "s3"],
            sparse_temporal_proportions=[0.33, 0.34, 0.33],
            dense_subtask_names=["d1", "d2"],
            dense_temporal_proportions=[0.5, 0.5],
        )

        # 3 sparse stages: [0-33), [33-66), [66-99]
        # 2 dense stages: [0-50), [50-100)
        episodes = [
            {
                "dataset_from_index": 0,
                "dataset_to_index": 100,
                "task": "test",
                "sparse_subtask_names": ["s1", "s2", "s3"],
                "sparse_subtask_start_frames": [0, 33, 66],
                "sparse_subtask_end_frames": [33, 66, 100],
                "dense_subtask_names": ["d1", "d2"],
                "dense_subtask_start_frames": [0, 50],
                "dense_subtask_end_frames": [50, 100],
            }
        ]
        dataset_meta = MockDatasetMeta(episodes)

        processor = SARMEncodingProcessorStep(config=config, dataset_meta=dataset_meta)
        processor.train(True)

        num_frames = config.num_frames

        # Test at frame 50 (middle of episode)
        dummy_image = np.random.rand(num_frames, 3, 224, 224).astype(np.float32)
        dummy_state = np.random.rand(num_frames, 6).astype(np.float32)

        transition = {
            TransitionKey.OBSERVATION: {
                config.image_key: dummy_image,
                config.state_key: dummy_state,
            },
            TransitionKey.COMPLEMENTARY_DATA: {
                "index": 50,
                "episode_index": 0,
                "task": "test",
            },
        }

        result = processor(transition)
        obs = result[TransitionKey.OBSERVATION]
        sparse_targets = obs["sparse_targets"][0]
        dense_targets = obs["dense_targets"][0]

        # Manually compute expected targets for observation frames
        # With frame_gap=10, n_obs_steps=8, center at 50:
        # obs frames: [10, 20, 30, 40, 50, 60, 70, 80, 90]
        expected_obs_frames = [10, 20, 30, 40, 50, 60, 70, 80, 90]

        sparse_names = ["s1", "s2", "s3"]
        sparse_starts = [0, 33, 66]
        sparse_ends = [33, 66, 100]
        sparse_props = {"s1": 0.33, "s2": 0.34, "s3": 0.33}

        dense_names = ["d1", "d2"]
        dense_starts = [0, 50]
        dense_ends = [50, 100]
        dense_props = {"d1": 0.5, "d2": 0.5}

        for i, frame in enumerate(expected_obs_frames):
            expected_sparse = find_stage_and_tau(
                frame,
                100,
                sparse_names,
                sparse_starts,
                sparse_ends,
                sparse_names,
                sparse_props,
                return_combined=True,
            )
            expected_dense = find_stage_and_tau(
                frame,
                100,
                dense_names,
                dense_starts,
                dense_ends,
                dense_names,
                dense_props,
                return_combined=True,
            )

            actual_sparse = sparse_targets[i].item()
            actual_dense = dense_targets[i].item()

            assert abs(actual_sparse - expected_sparse) < 0.01, (
                f"Frame {frame}: sparse mismatch {actual_sparse:.3f} vs expected {expected_sparse:.3f}"
            )
            assert abs(actual_dense - expected_dense) < 0.01, (
                f"Frame {frame}: dense mismatch {actual_dense:.3f} vs expected {expected_dense:.3f}"
            )
@@ -0,0 +1,134 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

pytest.importorskip("transformers")

from lerobot.data_processing.sarm_annotations.subtask_annotation import (
    Subtask,
    SubtaskAnnotation,
    Timestamp,
    compute_temporal_proportions,
)


def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
    """Helper to create a SubtaskAnnotation from a list of (name, start_sec, end_sec)."""
    return SubtaskAnnotation(
        subtasks=[
            Subtask(
                name=name,
                timestamps=Timestamp(
                    start=f"{start // 60:02d}:{start % 60:02d}", end=f"{end // 60:02d}:{end % 60:02d}"
                ),
            )
            for name, start, end in subtasks
        ]
    )


class TestComputeTemporalProportions:
    """Tests for compute_temporal_proportions (SARM Paper Formula 1).

    Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)

    Key insight: this averages the PROPORTION of each subtask within each trajectory,
    giving equal weight to all trajectories regardless of absolute length.
    """

    def test_basic_two_trajectories_equal_proportions(self):
        """Test with two trajectories that have equal proportions."""
        # Both trajectories: subtask1 = 50%, subtask2 = 50%
        # Traj 1: T=100s, subtask1=50s, subtask2=50s
        # Traj 2: T=200s, subtask1=100s, subtask2=100s
        annotations = {
            0: make_annotation([("subtask1", 0, 50), ("subtask2", 50, 100)]),
            1: make_annotation([("subtask1", 0, 100), ("subtask2", 100, 200)]),
        }

        result = compute_temporal_proportions(annotations)

        # Both should be 0.5
        assert abs(result["subtask1"] - 0.5) < 1e-6
        assert abs(result["subtask2"] - 0.5) < 1e-6
def test_paper_example_different_from_avg_durations(self):
|
||||||
|
"""Test that compute_temporal_proportions differs from naive average duration approach.
|
||||||
|
|
||||||
|
This is the key test showing the difference between:
|
||||||
|
- Paper formula: average of (L_i,k / T_i)
|
||||||
|
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
|
||||||
|
"""
|
||||||
|
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
|
||||||
|
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
|
||||||
|
annotations = {
|
||||||
|
0: make_annotation([("subtask1", 0, 80), ("subtask2", 80, 100)]),
|
||||||
|
1: make_annotation([("subtask1", 0, 40), ("subtask2", 40, 200)]),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
|
# Paper formula:
|
||||||
|
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
|
||||||
|
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
|
||||||
|
assert abs(result["subtask1"] - 0.5) < 1e-6
|
||||||
|
assert abs(result["subtask2"] - 0.5) < 1e-6
|
||||||
|
|
||||||
|
def test_single_trajectory(self):
|
||||||
|
"""Test with a single trajectory."""
|
||||||
|
# T=100s, reach=30s, grasp=20s, lift=50s
|
||||||
|
annotations = {
|
||||||
|
0: make_annotation([("reach", 0, 30), ("grasp", 30, 50), ("lift", 50, 100)]),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
|
assert abs(result["reach"] - 0.3) < 1e-6
|
||||||
|
assert abs(result["grasp"] - 0.2) < 1e-6
|
||||||
|
assert abs(result["lift"] - 0.5) < 1e-6
|
||||||
|
|
||||||
|
def test_sum_to_one(self):
|
||||||
|
"""Test that proportions always sum to 1."""
|
||||||
|
# Three episodes with varying proportions
|
||||||
|
annotations = {
|
||||||
|
0: make_annotation([("a", 0, 10), ("b", 10, 50), ("c", 50, 100)]), # 0.1, 0.4, 0.5
|
||||||
|
1: make_annotation([("a", 0, 20), ("b", 20, 70), ("c", 70, 100)]), # 0.2, 0.5, 0.3
|
||||||
|
2: make_annotation([("a", 0, 30), ("b", 30, 90), ("c", 90, 100)]), # 0.3, 0.6, 0.1
|
||||||
|
}
|
||||||
|
|
||||||
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
|
total = sum(result.values())
|
||||||
|
assert abs(total - 1.0) < 1e-6
|
||||||
|
|
||||||
|
def test_empty_annotations_returns_empty(self):
|
||||||
|
"""Test that empty annotations returns empty dict."""
|
||||||
|
result = compute_temporal_proportions({})
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_uniform_proportions(self):
|
||||||
|
"""Test with uniform proportions across subtasks."""
|
||||||
|
# Each subtask takes 25% of each episode
|
||||||
|
annotations = {
|
||||||
|
0: make_annotation([("a", 0, 25), ("b", 25, 50), ("c", 50, 75), ("d", 75, 100)]),
|
||||||
|
1: make_annotation([("a", 0, 50), ("b", 50, 100), ("c", 100, 150), ("d", 150, 200)]),
|
||||||
|
}
|
||||||
|
|
||||||
|
result = compute_temporal_proportions(annotations)
|
||||||
|
|
||||||
|
for name in ["a", "b", "c", "d"]:
|
||||||
|
assert abs(result[name] - 0.25) < 1e-6
|
||||||
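The tests above pin down what Formula 1 computes: the per-episode proportion of each subtask, averaged over episodes. A minimal standalone sketch of that computation, independent of the SubtaskAnnotation types (the function name and the assumption that each episode starts at 0 and ends at its last subtask's end are illustrative):

```python
def temporal_proportions_sketch(episodes: dict[int, list[tuple[str, float, float]]]) -> dict[str, float]:
    """Formula 1: average each subtask's within-episode proportion L_{i,k} / T_i over M episodes."""
    if not episodes:
        return {}
    sums: dict[str, float] = {}
    for subtasks in episodes.values():
        # Assumed episode length T_i: from 0 to the end of the last subtask.
        total = max(end for _, _, end in subtasks)
        for name, start, end in subtasks:
            sums[name] = sums.get(name, 0.0) + (end - start) / total
    m = len(episodes)
    return {name: s / m for name, s in sums.items()}
```

On the paper-style example above (episode 1: 80/100 and 20/100; episode 2: 40/200 and 160/200), both subtasks average to 0.5, whereas pooling absolute durations would weight the longer episode more heavily.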
@@ -0,0 +1,615 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import torch

from lerobot.policies.sarm.sarm_utils import (
    apply_rewind_augmentation,
    compute_absolute_indices,
    compute_tau,
    find_stage_and_tau,
    normalize_stage_tau,
    temporal_proportions_to_breakpoints,
)


class TestProgressLabelsWithModes:
    """End-to-end tests for progress label generation in different modes."""

    def test_sparse_mode_single_stage(self):
        """Sparse mode with a single stage should give linear progress."""
        episode_length = 300
        global_names = ["task"]
        proportions = {"task": 1.0}

        # Test at various frames
        for frame in [0, 100, 200, 299]:
            stage, tau = find_stage_and_tau(
                frame, episode_length, None, None, None, global_names, proportions
            )

            expected_tau = frame / (episode_length - 1)
            assert stage == 0
            assert abs(tau - expected_tau) < 1e-5

    def test_sparse_mode_multi_stage(self):
        """Sparse mode with multiple stages."""
        global_names = ["reach", "grasp", "lift", "place"]
        proportions = {"reach": 0.2, "grasp": 0.2, "lift": 0.3, "place": 0.3}

        subtask_names = ["reach", "grasp", "lift", "place"]
        subtask_starts = [0, 60, 120, 210]
        subtask_ends = [59, 119, 209, 299]

        # Check that stages are correctly identified
        stage_at_30, _ = find_stage_and_tau(
            30, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage_at_30 == 0

        stage_at_90, _ = find_stage_and_tau(
            90, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage_at_90 == 1

        stage_at_150, _ = find_stage_and_tau(
            150, 300, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage_at_150 == 2

    def test_dense_mode_more_stages(self):
        """Dense mode should work with more fine-grained stages."""
        global_names = ["a", "b", "c", "d", "e", "f", "g", "h"]
        proportions = dict.fromkeys(global_names, 1 / 8)

        subtask_names = global_names
        subtask_starts = [i * 50 for i in range(8)]
        subtask_ends = [(i + 1) * 50 - 1 for i in range(8)]

        # Each stage should occupy 50 frames
        for stage_idx in range(8):
            mid_frame = stage_idx * 50 + 25
            stage, _ = find_stage_and_tau(
                mid_frame, 400, subtask_names, subtask_starts, subtask_ends, global_names, proportions
            )
            assert stage == stage_idx


class TestComputeAbsoluteIndices:
    """Tests for compute_absolute_indices (bidirectional sampling)."""

    def test_no_clamping_when_in_middle(self):
        """When the frame is in the middle of the episode, no clamping should occur."""
        frame_idx = 300
        ep_start = 0
        ep_end = 1000
        n_obs_steps = 8
        frame_gap = 30

        indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)

        # All should be valid (no out of bounds)
        assert out_of_bounds.sum() == 0

        # Check bidirectional indices: [-120, -90, -60, -30, 0, 30, 60, 90, 120] from center
        half_steps = n_obs_steps // 2
        expected = (
            [frame_idx - frame_gap * i for i in range(half_steps, 0, -1)]
            + [frame_idx]
            + [frame_idx + frame_gap * i for i in range(1, half_steps + 1)]
        )
        assert indices.tolist() == expected

        # Center frame (index 4) should be frame_idx
        assert indices[half_steps] == frame_idx

    def test_clamping_at_episode_start(self):
        """Early frames should be clamped to the episode start."""
        frame_idx = 50  # Not enough history for a full past window
        ep_start = 0
        ep_end = 1000
        n_obs_steps = 8
        frame_gap = 30

        indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)

        # Some past frames should be clamped (out_of_bounds = 1)
        assert out_of_bounds.sum() > 0

        # All indices should be >= ep_start
        assert (indices >= ep_start).all()

        # Center index should be frame_idx
        half_steps = n_obs_steps // 2
        assert indices[half_steps] == frame_idx

    def test_clamping_at_episode_end(self):
        """Late frames should be clamped to the episode end."""
        frame_idx = 950  # Not enough future frames for a full window
        ep_start = 0
        ep_end = 1000
        n_obs_steps = 8
        frame_gap = 30

        indices, out_of_bounds = compute_absolute_indices(frame_idx, ep_start, ep_end, n_obs_steps, frame_gap)

        # Some future frames should be clamped
        assert out_of_bounds.sum() > 0

        # All indices should be < ep_end
        assert (indices < ep_end).all()

        # Center index should be frame_idx
        half_steps = n_obs_steps // 2
        assert indices[half_steps] == frame_idx

    def test_sequence_is_monotonic(self):
        """Frame indices should be monotonically increasing."""
        for frame_idx in [50, 100, 300, 950]:
            indices, _ = compute_absolute_indices(frame_idx, 0, 1000, 8, 30)

            # Check monotonic (non-decreasing due to clamping)
            diffs = indices[1:] - indices[:-1]
            assert (diffs >= 0).all(), f"Non-monotonic at frame {frame_idx}"


class TestComputeTau:
    """Tests for compute_tau (within-subtask progress).

    Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
    """

    def test_at_start(self):
        """τ should be 0 at the subtask start."""
        tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
        assert tau == 0.0

    def test_at_end(self):
        """τ should be 1 at the subtask end."""
        tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
        assert tau == 1.0

    def test_at_middle(self):
        """τ should be 0.5 at the subtask midpoint."""
        tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
        assert abs(tau - 0.5) < 1e-6

    def test_quarter_progress(self):
        """Test τ at 25% through the subtask."""
        tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
        assert abs(tau - 0.25) < 1e-6

    def test_zero_duration_subtask(self):
        """τ should be 1.0 for a zero-duration subtask."""
        tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
        assert tau == 1.0

    def test_clamps_below_zero(self):
        """τ should be clamped to 0 if the frame is before the subtask."""
        tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
        assert tau == 0.0

    def test_clamps_above_one(self):
        """τ should be clamped to 1 if the frame is after the subtask."""
        tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
        assert tau == 1.0

    def test_float_inputs(self):
        """Test with float frame indices (from interpolation)."""
        tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
        expected = (25.5 - 10.0) / (50.0 - 10.0)
        assert abs(tau - expected) < 1e-6


class TestFindStageAndTau:
    """Tests for find_stage_and_tau logic.

    This function is the core of progress label computation. It determines
    which stage a frame belongs to and the within-stage progress (tau).
    """

    def test_single_stage_mode_linear_progress(self):
        """Single-stage mode should give linear progress from 0 to 1."""
        episode_length = 100

        # Frame 0 -> tau = 0
        stage, tau = find_stage_and_tau(0, episode_length, None, None, None, ["task"], {"task": 1.0})
        assert stage == 0
        assert abs(tau - 0.0) < 1e-6

        # Frame 50 -> tau = 50/99 ≈ 0.505
        stage, tau = find_stage_and_tau(50, episode_length, None, None, None, ["task"], {"task": 1.0})
        assert stage == 0
        assert abs(tau - 50 / 99) < 1e-6

        # Frame 99 -> tau = 1.0
        stage, tau = find_stage_and_tau(99, episode_length, None, None, None, ["task"], {"task": 1.0})
        assert stage == 0
        assert abs(tau - 1.0) < 1e-6

    def test_multi_stage_within_subtask(self):
        """Test finding the stage when the frame is within a subtask."""
        global_names = ["reach", "grasp", "lift"]
        proportions = {"reach": 0.3, "grasp": 0.2, "lift": 0.5}

        subtask_names = ["reach", "grasp", "lift"]
        subtask_starts = [0, 30, 50]
        subtask_ends = [29, 49, 99]

        # Frame 15 in "reach" stage (index 0)
        stage, tau = find_stage_and_tau(
            15, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 0
        assert abs(tau - 15 / 29) < 1e-6

        # Frame 40 in "grasp" stage (index 1)
        stage, tau = find_stage_and_tau(
            40, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 1
        # tau = (40 - 30) / (49 - 30) = 10/19
        assert abs(tau - 10 / 19) < 1e-6

        # Frame 75 in "lift" stage (index 2)
        stage, tau = find_stage_and_tau(
            75, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 2
        # tau = (75 - 50) / (99 - 50) = 25/49
        assert abs(tau - 25 / 49) < 1e-6

    def test_frame_at_subtask_boundaries(self):
        """Test frames exactly at subtask boundaries."""
        global_names = ["a", "b"]
        proportions = {"a": 0.5, "b": 0.5}

        subtask_names = ["a", "b"]
        subtask_starts = [0, 50]
        subtask_ends = [49, 99]

        # Frame at the start of the first subtask
        stage, tau = find_stage_and_tau(
            0, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 0
        assert tau == 0.0

        # Frame at the end of the first subtask
        stage, tau = find_stage_and_tau(
            49, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 0
        assert tau == 1.0

        # Frame at the start of the second subtask
        stage, tau = find_stage_and_tau(
            50, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 1
        assert tau == 0.0

    def test_frame_after_last_subtask(self):
        """Frames after the last subtask should return the last stage with high tau."""
        global_names = ["a", "b"]
        proportions = {"a": 0.5, "b": 0.5}

        subtask_names = ["a", "b"]
        subtask_starts = [0, 30]
        subtask_ends = [29, 59]

        # Frame 80 is after the last subtask
        stage, tau = find_stage_and_tau(
            80, 100, subtask_names, subtask_starts, subtask_ends, global_names, proportions
        )
        assert stage == 1  # Last stage
        assert tau == 0.999  # Nearly complete


class TestEndToEndProgressLabeling:
    """End-to-end tests for progress label computation using normalize_stage_tau."""

    def test_consistent_semantic_meaning(self):
        """Test that the same subtask completion maps to the same progress across trajectories.

        This is the key semantic property: "end of subtask 1" should always
        mean the same progress value regardless of trajectory speed.
        """
        proportions = [0.3, 0.5, 0.2]

        # Fast trajectory: subtask 1 ends at frame 30 (of 100)
        tau_fast = compute_tau(30, 0, 30)  # = 1.0
        y_fast = normalize_stage_tau(0 + tau_fast, temporal_proportions=proportions)

        # Slow trajectory: subtask 1 ends at frame 90 (of 300)
        tau_slow = compute_tau(90, 0, 90)  # = 1.0
        y_slow = normalize_stage_tau(0 + tau_slow, temporal_proportions=proportions)

        # Both should map to the same progress (0.3 = end of subtask 1)
        assert abs(y_fast - y_slow) < 1e-6
        assert abs(y_fast - 0.3) < 1e-6

    def test_monotonic_within_subtask(self):
        """Test that progress is monotonically increasing within a subtask."""
        proportions = [0.4, 0.6]

        prev_y = -1
        for tau in np.linspace(0, 1, 11):
            y = normalize_stage_tau(0 + tau, temporal_proportions=proportions)
            assert y > prev_y or (tau == 0 and y == 0)
            prev_y = y

    def test_continuous_across_subtasks(self):
        """Test that progress is continuous at subtask boundaries."""
        proportions = [0.3, 0.5, 0.2]

        # End of subtask 0 (stage=0, tau=1.0) -> stage.tau = 1.0
        y_end_0 = normalize_stage_tau(0 + 1.0, temporal_proportions=proportions)

        # Start of subtask 1 (stage=1, tau=0.0) -> stage.tau = 1.0
        y_start_1 = normalize_stage_tau(1 + 0.0, temporal_proportions=proportions)

        # Should be equal (P_1 = 0.3)
        assert abs(y_end_0 - y_start_1) < 1e-6

        # End of subtask 1 (stage=1, tau=1.0) -> stage.tau = 2.0
        y_end_1 = normalize_stage_tau(1 + 1.0, temporal_proportions=proportions)

        # Start of subtask 2 (stage=2, tau=0.0) -> stage.tau = 2.0
        y_start_2 = normalize_stage_tau(2 + 0.0, temporal_proportions=proportions)

        # Should be equal (P_2 = 0.8)
        assert abs(y_end_1 - y_start_2) < 1e-6


class TestTemporalProportionsToBreakpoints:
    """Tests for temporal_proportions_to_breakpoints.

    Converts temporal proportions to cumulative breakpoints for normalization.
    Example: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]
    """

    def test_basic_conversion(self):
        """Test basic conversion from proportions to breakpoints."""
        proportions = [0.3, 0.5, 0.2]
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        assert breakpoints is not None
        assert len(breakpoints) == 4
        assert breakpoints[0] == 0.0
        assert abs(breakpoints[1] - 0.3) < 1e-6
        assert abs(breakpoints[2] - 0.8) < 1e-6
        assert breakpoints[3] == 1.0

    def test_dict_input(self):
        """Test with dict input."""
        proportions = {"a": 0.25, "b": 0.25, "c": 0.5}
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        assert breakpoints is not None
        assert len(breakpoints) == 4
        assert breakpoints[0] == 0.0
        assert breakpoints[-1] == 1.0

    def test_dict_with_subtask_names_order(self):
        """Test that subtask_names determines the order for dict input."""
        proportions = {"c": 0.5, "a": 0.2, "b": 0.3}  # Dict order
        subtask_names = ["a", "b", "c"]  # Different order

        breakpoints = temporal_proportions_to_breakpoints(proportions, subtask_names)

        # Breakpoints should follow subtask_names order: a=0.2, b=0.3, c=0.5
        assert abs(breakpoints[1] - 0.2) < 1e-6  # a
        assert abs(breakpoints[2] - 0.5) < 1e-6  # a + b = 0.5
        assert breakpoints[3] == 1.0  # a + b + c = 1.0

    def test_uniform_proportions(self):
        """Test with uniform proportions."""
        proportions = [0.25, 0.25, 0.25, 0.25]
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        expected = [0.0, 0.25, 0.5, 0.75, 1.0]
        for i, (bp, exp) in enumerate(zip(breakpoints, expected, strict=True)):
            assert abs(bp - exp) < 1e-6, f"Breakpoint {i} mismatch"

    def test_none_input(self):
        """Test that None input returns None."""
        result = temporal_proportions_to_breakpoints(None)
        assert result is None

    def test_normalization(self):
        """Test that non-normalized proportions are normalized."""
        # Proportions sum to 2.0, not 1.0
        proportions = [0.6, 1.0, 0.4]
        breakpoints = temporal_proportions_to_breakpoints(proportions)

        # Should be normalized: [0.3, 0.5, 0.2] -> [0.0, 0.3, 0.8, 1.0]
        assert breakpoints[-1] == 1.0
        assert abs(breakpoints[1] - 0.3) < 1e-6


class TestNormalizeStageTau:
    """Tests for normalize_stage_tau.

    Normalizes stage+tau values to [0, 1] using breakpoints.
    """

    def test_linear_fallback(self):
        """Test linear normalization when only num_stages is provided."""
        # 4 stages, linear: [0, 0.25, 0.5, 0.75, 1.0]

        # Stage 0 start
        assert normalize_stage_tau(0.0, num_stages=4) == 0.0

        # Stage 0 end / stage 1 start
        assert abs(normalize_stage_tau(1.0, num_stages=4) - 0.25) < 1e-6

        # Stage 1 middle
        assert abs(normalize_stage_tau(1.5, num_stages=4) - 0.375) < 1e-6

        # Stage 3 end
        assert normalize_stage_tau(4.0, num_stages=4) == 1.0

    def test_with_custom_breakpoints(self):
        """Test with custom breakpoints."""
        # Non-linear breakpoints
        breakpoints = [0.0, 0.1, 0.5, 1.0]  # 3 stages

        # Stage 0: maps [0, 1) to [0.0, 0.1)
        assert abs(normalize_stage_tau(0.5, breakpoints=breakpoints) - 0.05) < 1e-6

        # Stage 1: maps [1, 2) to [0.1, 0.5)
        assert abs(normalize_stage_tau(1.5, breakpoints=breakpoints) - 0.3) < 1e-6

        # Stage 2: maps [2, 3) to [0.5, 1.0)
        assert abs(normalize_stage_tau(2.5, breakpoints=breakpoints) - 0.75) < 1e-6

    def test_with_temporal_proportions(self):
        """Test with temporal proportions (auto-computed breakpoints)."""
        proportions = {"a": 0.2, "b": 0.3, "c": 0.5}
        subtask_names = ["a", "b", "c"]

        # Stage 0 end should map to 0.2
        result = normalize_stage_tau(1.0, temporal_proportions=proportions, subtask_names=subtask_names)
        assert abs(result - 0.2) < 1e-6

        # Stage 1 end should map to 0.5
        result = normalize_stage_tau(2.0, temporal_proportions=proportions, subtask_names=subtask_names)
        assert abs(result - 0.5) < 1e-6

    def test_tensor_input(self):
        """Test with tensor input."""
        x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
        breakpoints = [0.0, 0.3, 0.8, 1.0]  # 3 stages

        result = normalize_stage_tau(x, breakpoints=breakpoints)

        assert isinstance(result, torch.Tensor)
        assert result.shape == x.shape
        assert abs(result[0].item() - 0.0) < 1e-6
        assert abs(result[2].item() - 0.3) < 1e-6  # End of stage 0
        assert abs(result[4].item() - 0.8) < 1e-6  # End of stage 1

    def test_clamping(self):
        """Test that the output is clamped to [0, 1]."""
        # Below 0
        assert normalize_stage_tau(-0.5, num_stages=4) == 0.0

        # Above num_stages
        assert normalize_stage_tau(5.0, num_stages=4) == 1.0

    def test_batch_tensor(self):
        """Test with a batched tensor."""
        x = torch.tensor([[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]])  # (2, 3)

        result = normalize_stage_tau(x, num_stages=3)

        assert result.shape == (2, 3)
        assert (result >= 0).all()
        assert (result <= 1).all()

    def test_requires_one_of_inputs(self):
        """Test that at least one input method is required."""
        with pytest.raises(ValueError):
            normalize_stage_tau(1.0)


class TestRewindAugmentation:
    """Tests for rewind augmentation logic with bidirectional observation sampling.

    Rewind appends frames before the earliest observation frame, going backwards.
    With bidirectional sampling centered at frame_idx:
    - Earliest obs frame = frame_idx - half_steps * frame_gap
    - Rewind goes backwards from that point
    """

    def test_rewind_indices_go_backwards_from_earliest_obs(self):
        """Rewind indices should go backwards from the earliest observation frame."""
        frame_idx = 300  # Center of the bidirectional window
        ep_start = 0
        n_obs_steps = 4  # half_steps = 2
        frame_gap = 30

        # Earliest obs frame = 300 - 2*30 = 240
        # Rewind goes backwards: 210, 180
        rewind_step, rewind_indices = apply_rewind_augmentation(
            frame_idx,
            ep_start,
            n_obs_steps=n_obs_steps,
            max_rewind_steps=2,
            frame_gap=frame_gap,
            rewind_step=2,
        )

        assert rewind_step == 2
        assert len(rewind_indices) == 2
        # The first rewind frame is closest to the obs window; the second is further back
        assert rewind_indices[0] == 210  # 240 - 30
        assert rewind_indices[1] == 180  # 240 - 60
        assert rewind_indices[0] > rewind_indices[1], "Rewind should be descending"

    def test_rewind_goes_backward_through_history(self):
        """Rewind frames should go backward before the observation window."""
        frame_idx = 450  # Center of the bidirectional window
        ep_start = 0
        n_obs_steps = 8  # half_steps = 4
        frame_gap = 30

        # Earliest obs frame = 450 - 4*30 = 330
        # Rewind from 330: [300, 270, 240]
        rewind_step, rewind_indices = apply_rewind_augmentation(
            frame_idx,
            ep_start,
            n_obs_steps=n_obs_steps,
            max_rewind_steps=4,
            frame_gap=frame_gap,
            rewind_step=3,
        )

        assert rewind_step == 3
        expected = [300, 270, 240]  # Going backwards from 330
        assert rewind_indices == expected

    def test_no_rewind_when_obs_window_at_episode_start(self):
        """No rewind when the observation window reaches the episode start."""
        frame_idx = 120  # Center of the window
        ep_start = 0
        n_obs_steps = 8  # half_steps = 4
        frame_gap = 30

        # Earliest obs frame = 120 - 4*30 = 0 (at episode start)
        rewind_step, rewind_indices = apply_rewind_augmentation(
            frame_idx, ep_start, n_obs_steps=n_obs_steps, max_rewind_steps=4, frame_gap=frame_gap
        )

        # No room for rewind
        assert rewind_step == 0
        assert rewind_indices == []

    def test_rewind_targets_are_decreasing(self):
        """Progress targets for rewind frames should be decreasing."""
        # Simulate progress values
        obs_progress = [0.1, 0.2, 0.3, 0.4, 0.5]  # Forward progress

        # Rewind reverses progress
        rewind_indices = [4, 3, 2]  # Go backwards through the indices
        rewind_progress = [obs_progress[i] for i in rewind_indices]

        # Should be decreasing
        for i in range(len(rewind_progress) - 1):
            assert rewind_progress[i] > rewind_progress[i + 1]
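The breakpoint and normalization tests above describe a piecewise-linear mapping: cumulative proportions become breakpoints, and a stage+tau value is mapped linearly within its breakpoint segment. A minimal standalone sketch of both steps (function names are illustrative; the library's `temporal_proportions_to_breakpoints` and `normalize_stage_tau` also handle dict, tensor, and num_stages inputs not shown here):

```python
def breakpoints_sketch(proportions: list[float]) -> list[float]:
    """Cumulative sums with a leading 0.0, normalized so the last breakpoint is 1.0."""
    total = sum(proportions)
    bps, acc = [0.0], 0.0
    for p in proportions:
        acc += p / total
        bps.append(acc)
    return bps


def normalize_stage_tau_sketch(x: float, breakpoints: list[float]) -> float:
    """Map stage.tau (stage = floor(x), tau = frac(x)) linearly into [bp[k], bp[k+1]]."""
    n = len(breakpoints) - 1  # number of stages
    x = min(max(x, 0.0), float(n))  # clamp to valid stage range
    k = min(int(x), n - 1)  # stage index (x == n falls into the last segment)
    tau = x - k
    return breakpoints[k] + tau * (breakpoints[k + 1] - breakpoints[k])
```

With proportions [0.3, 0.5, 0.2] this yields breakpoints [0.0, 0.3, 0.8, 1.0], so a stage.tau of 2.0 (end of stage 1) maps to 0.8, matching the continuity property the end-to-end tests check.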