mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c78023dae7 | |||
| 36d0ba5127 | |||
| dca792951e | |||
| 0a369e104a | |||
| b0cdf99957 | |||
| 733f9768b5 | |||
| 7fe49f9e54 | |||
| e1afb96474 | |||
| f395f36dec | |||
| 738ba9272f | |||
| 2a0495f8c3 | |||
| c3c9c2b089 | |||
| e13c6a6110 | |||
| 140cf2a420 | |||
| c092194cf2 | |||
| b858ba1b6c | |||
| e870af119f | |||
| 4174c3b303 |
@@ -105,7 +105,7 @@ lerobot-train \
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@
|
||||
- sections:
|
||||
- local: il_robots
|
||||
title: Imitation Learning for Robots
|
||||
- local: lelab
|
||||
title: LeLab - Lerobot GUI
|
||||
- local: bring_your_own_policies
|
||||
title: Adding a Policy
|
||||
- local: integrate_hardware
|
||||
@@ -63,12 +61,10 @@
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: molmoact2
|
||||
title: MolmoAct2
|
||||
- local: vla_jepa
|
||||
title: VLA-JEPA
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: multi_task_dit
|
||||
@@ -79,8 +75,6 @@
|
||||
- sections:
|
||||
- local: sarm
|
||||
title: SARM
|
||||
- local: robometer
|
||||
title: ROBOMETER
|
||||
- local: topreward
|
||||
title: TOPReward
|
||||
title: "Reward Models"
|
||||
|
||||
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
|
||||
|
||||
- [SmolVLA](./smolvla)
|
||||
- [Pi0.5](./pi05)
|
||||
- [GR00T N1.7](./groot)
|
||||
- [GR00T N1.5](./groot)
|
||||
|
||||
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
|
||||
|
||||
|
||||
+30
-79
@@ -1,19 +1,16 @@
|
||||
# GR00T Policy
|
||||
# GR00T N1.5 Policy
|
||||
|
||||
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
|
||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||
|
||||
> [!WARNING]
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
|
||||
## Model Overview
|
||||
|
||||
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
|
||||
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
|
||||
@@ -31,46 +28,33 @@ This approach allows the model to be highly adaptable through post-training for
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
|
||||
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
|
||||
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
|
||||
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
|
||||
```
|
||||
|
||||
3. Install and verify Flash Attention:
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
4. Install LeRobot with the GR00T extra:
|
||||
3. Install LeRobot by running:
|
||||
|
||||
```bash
|
||||
pip install "lerobot[groot]"
|
||||
pip install lerobot[groot]
|
||||
```
|
||||
|
||||
For a source checkout, use the same order, then install the local package with:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T N1.7:
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```bash
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7
|
||||
```python
|
||||
policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
@@ -103,54 +87,21 @@ accelerate launch \
|
||||
|
||||
## Performance Results
|
||||
|
||||
### LIBERO Benchmark Results
|
||||
### Libero Benchmark Results
|
||||
|
||||
> [!NOTE]
|
||||
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
|
||||
> Follow our instructions for Libero usage: [Libero](./libero)
|
||||
|
||||
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
|
||||
### GR00T N1.7 LIBERO Checkpoints
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
|
||||
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
|
||||
|
||||
| Suite | Checkpoint subdirectory |
|
||||
| -------------- | ----------------------- |
|
||||
| LIBERO Spatial | `libero_spatial` |
|
||||
| LIBERO Object | `libero_object` |
|
||||
| LIBERO Goal | `libero_goal` |
|
||||
| LIBERO 10 | `libero_10` |
|
||||
|
||||
Preliminary LeRobot integration results:
|
||||
|
||||
| Suite | Status | Success rate | n_episodes |
|
||||
| -------------- | ------ | -----------: | ---------: |
|
||||
| LIBERO Spatial | ✓ | ~95% | XX |
|
||||
| LIBERO Object | ✓ | XX% | XX |
|
||||
| LIBERO Goal | ✓ | XX% | XX |
|
||||
| LIBERO 10 | ✓ | XX% | XX |
|
||||
| **Average** | ✓ | **XX%** | **XX** |
|
||||
|
||||
Replace the `XX` placeholders with final eval artifacts before merge.
|
||||
|
||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||
|
||||
```bash
|
||||
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
--include "libero_spatial/*" \
|
||||
--local-dir ./GR00T-N1.7-LIBERO
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7 \
|
||||
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--eval.n_episodes=50
|
||||
```
|
||||
|
||||
Use `eval.n_episodes >= 50` per suite when reporting success rates.
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
@@ -180,4 +131,4 @@ lerobot-rollout\
|
||||
|
||||
## License
|
||||
|
||||
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
|
||||
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# LeLab - LeRobot Guide
|
||||
|
||||
LeLab is a graphical user interface built on top of the LeRobot library, designed to make robotics accessible without needing to memorize CLI commands. From a single app you can configure your robot, teleoperate it, collect datasets, train policies locally or on cloud GPUs via HF Jobs, and deploy trained models back onto your robot. It's the easiest way to go from an unboxed SO-101 to a working policy, and a great companion for anyone learning the LeRobot workflow. Source code and issues live on GitHub: [huggingface/leLab](https://github.com/huggingface/leLab).
|
||||
|
||||
> [!TIP]
|
||||
> For now LeLab is compatible only with SO-ARM101
|
||||
|
||||
<Youtube id="VqyKUuW9V1g" />
|
||||
|
||||
### Installation
|
||||
|
||||
Requires [`uv`](https://docs.astral.sh/uv/getting-started/installation/). Install and launch in one command:
|
||||
|
||||
```
|
||||
uv tool install git+https://github.com/huggingface/leLab.git && lelab
|
||||
```
|
||||
|
||||
After install, run `lelab` from your terminal anytime to start the app.
|
||||
|
||||
### Features
|
||||
|
||||
- **Add robots** — Select arm type (leader/follower), calibrate each joint from the middle position, and attach cameras.
|
||||
- **Teleoperation** — Control the follower arm with the leader and see a live 3D visualization of the arms.
|
||||
- **Dataset recording** — Define a task description, number of episodes, and episode/reset durations. Press spacebar to advance between episodes. 30+ episodes recommended.
|
||||
- **Local training** — Train a policy directly on your own machine with a selected dataset, policy type, batch size, and step count.
|
||||
- **Cloud training with HF Jobs** — Train on powerful GPUs via [HF Jobs](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) with transparent pricing. Run `hf auth login` first. See the [Compute HW Guide](hardware_guide) for hardware/batch size tips.
|
||||
- **Training visualization** — Watch progress live in the app, with checkpoints saved automatically.
|
||||
- **Run trained policies** — Pick any model from your jobs list and run inference on your robot with one click.
|
||||
- **Use community datasets** — Provide any Hugging Face dataset ID to train on datasets you didn't record yourself.
|
||||
@@ -275,7 +275,7 @@ A converter aggregates per‑episode files into larger shards and writes episode
|
||||
pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip"
|
||||
|
||||
# Convert an existing v2.1 dataset hosted on the Hub:
|
||||
python -m lerobot.scripts.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id=<HF_USER/DATASET_ID>
|
||||
```
|
||||
|
||||
**What it does**
|
||||
|
||||
@@ -238,7 +238,7 @@ your dataset has not been converted with quantile statistics, you can add them
|
||||
with:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ lerobot-train \
|
||||
If your dataset is not converted with `quantiles`, you can convert it with the following command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/augment_dataset_quantile_stats.py \
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=your_dataset \
|
||||
```
|
||||
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
## Research Paper
|
||||
|
||||
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
|
||||
|
||||
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
|
||||
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
|
||||
> Current releases support GR00T N1.7 only.
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
## Repository
|
||||
|
||||
@@ -31,103 +24,4 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Models:
|
||||
|
||||
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
|
||||
|
||||
## Original-vs-LeRobot parity test
|
||||
|
||||
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
|
||||
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
||||
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
|
||||
over every embodiment tag present in the checkpoint:
|
||||
|
||||
1. **Model parity** — given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed (recorded in each artifact), both implementations must produce
|
||||
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
|
||||
flow-matching prediction). Output shapes must match exactly; any action-horizon
|
||||
or action-dim mismatch fails the test.
|
||||
2. **Preprocessor parity** — given the identical raw observations (per-camera
|
||||
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
|
||||
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
|
||||
state normalization, no mocks) must produce the **same collated model inputs**
|
||||
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
|
||||
`embodiment_id`) as the original package's processor.
|
||||
|
||||
### Why two environments
|
||||
|
||||
The original `gr00t` package pins `transformers==4.57.3` (Python 3.10); this
|
||||
integration requires `transformers>=5.x` (Qwen3-VL). Under 5.x, `PretrainedConfig`
|
||||
is itself a defaulted dataclass, so the original config dataclasses fail to import
|
||||
(`non-default argument follows default argument`). The two implementations therefore
|
||||
**cannot be imported in the same Python process**.
|
||||
|
||||
So the test uses a **producer / consumer** split across two venvs:
|
||||
|
||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
|
||||
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
||||
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
||||
the processor modality configs), runs the original model, and saves to one `.npz`
|
||||
per tag: the raw observations (`raw::` keys), the exact collated inputs
|
||||
(`in::` keys), the seed, and the raw `action_pred`.
|
||||
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
|
||||
`.npz`; the model-parity case replays the byte-identical collated inputs through
|
||||
the LeRobot model with the recorded seed and asserts the outputs match, and the
|
||||
preprocessor-parity case replays the raw observations through LeRobot's full
|
||||
preprocessor pipeline and asserts the collated tensors match.
|
||||
|
||||
> Artifacts generated by older versions of the dump script contain no `raw::`
|
||||
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
|
||||
> Re-run the producer to refresh them.
|
||||
|
||||
### Fairness controls
|
||||
|
||||
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
|
||||
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
|
||||
model comparison isolates the model. LeRobot's own tokenization / image packing is
|
||||
covered separately by the preprocessor-parity case, which compares its output
|
||||
against those same collated tensors from identical raw observations.
|
||||
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
||||
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
||||
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
||||
kernel/rounding noise, not an implementation difference.)
|
||||
- **Same flow-matching seed** — fixed right before sampling on both sides; the
|
||||
producer records it in each artifact (`--seed`, default 42) and the consumer
|
||||
replays the recorded value.
|
||||
|
||||
### How to run
|
||||
|
||||
```bash
|
||||
# Resolve a local checkpoint (GR00T-N1.7-LIBERO / libero_10)
|
||||
CKPT=$(python - <<'PY'
|
||||
import os
|
||||
from huggingface_hub import snapshot_download
|
||||
print(os.path.join(snapshot_download("nvidia/GR00T-N1.7-LIBERO",
|
||||
allow_patterns=["libero_10/*"]), "libero_10"))
|
||||
PY
|
||||
)
|
||||
|
||||
# 1) Produce the original-side artifacts for all embodiments (original gr00t venv, CUDA)
|
||||
CUDA_VISIBLE_DEVICES=0 /path/to/Isaac-GR00T/.venv-original/bin/python \
|
||||
tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
--ckpt "$CKPT" --out-dir tests/policies/groot/artifacts --device cuda --seed 42
|
||||
|
||||
# 2) Run the parity test (LeRobot venv) — one parametrized case per embodiment
|
||||
CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
|
||||
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
||||
```
|
||||
|
||||
The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
|
||||
the producer; they are never committed. The tests **skip** (do not fail) on CI or
|
||||
when the checkpoint / artifacts are absent.
|
||||
|
||||
#### Env knobs (all optional)
|
||||
|
||||
| Var | Default | Purpose |
|
||||
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
|
||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
At inference time only the Qwen backbone and action head are used; the world model is not needed.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -300,7 +300,7 @@ This replaces the old episode-per-file structure with efficient, optimally-sized
|
||||
If you have existing datasets in v2.1 format, use the migration tool:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/convert_dataset_v21_to_v30.py \
|
||||
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id your_id/existing_dataset
|
||||
```
|
||||
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
# ROBOMETER
|
||||
|
||||
ROBOMETER is a **general-purpose video-language robotic reward model**. It predicts dense, frame-level task progress and frame-level success from a trajectory video and a task description.
|
||||
|
||||
**Paper**: [ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons](https://arxiv.org/abs/2603.02115)
|
||||
**Project**: [robometer.github.io](https://robometer.github.io/)
|
||||
**Original code**: [github.com/robometer/robometer](https://github.com/robometer/robometer)
|
||||
**Checkpoint**: [lerobot/Robometer-4B](https://huggingface.co/lerobot/Robometer-4B)
|
||||
|
||||
## Overview
|
||||
|
||||
ROBOMETER builds on `Qwen/Qwen3-VL-4B-Instruct` and adds three lightweight prediction heads:
|
||||
|
||||
- **Progress head**: predicts per-frame task progress in `[0, 1]`.
|
||||
- **Success head**: predicts per-frame task success probability.
|
||||
- **Preference head**: predicts which of two trajectories better completes the task during training.
|
||||
|
||||
The paper trains ROBOMETER with a composite objective:
|
||||
|
||||
```text
|
||||
L = L_pref + L_prog + L_succ
|
||||
```
|
||||
|
||||
The LeRobot integration is currently **inference-only**. It preserves the preference head so that the published `Robometer-4B` checkpoint loads without remapping, but `compute_reward()` queries the progress or success head only.
|
||||
|
||||
## What the LeRobot Integration Covers
|
||||
|
||||
- Standard `reward_model.type=robometer` configuration through LeRobot.
|
||||
- Qwen3-VL image and text preprocessing through `RobometerEncoderProcessorStep`.
|
||||
- LeRobot reward-model save/load APIs through `PreTrainedRewardModel`.
|
||||
- Dense, frame-level progress and success predictions internally.
|
||||
- A scalar reward through `compute_reward()` for downstream LeRobot reward-model usage.
|
||||
|
||||
This page focuses on using the published ROBOMETER checkpoint as a zero-shot reward model. Training ROBOMETER from scratch is outside the current LeRobot integration.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install the ROBOMETER dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[robometer]"
|
||||
```
|
||||
|
||||
If you use `uv` directly from a source checkout:
|
||||
|
||||
```bash
|
||||
uv sync --extra robometer
|
||||
```
|
||||
|
||||
ROBOMETER uses a Qwen3-VL-4B backbone, so GPU inference is strongly recommended.
|
||||
|
||||
## Model Inputs and Outputs
|
||||
|
||||
ROBOMETER expects:
|
||||
|
||||
- A trajectory video or sequence of frames.
|
||||
- A natural-language task description.
|
||||
|
||||
In LeRobot datasets, the preprocessor reads:
|
||||
|
||||
| Config field | Default | Meaning |
|
||||
| ------------------------- | ------------------------ | ----------------------------------------------------- |
|
||||
| `reward_model.image_key` | `observation.images.top` | Camera/video observation used by ROBOMETER |
|
||||
| `reward_model.task_key` | `task` | Key in complementary data that stores the task string |
|
||||
| `reward_model.max_frames` | `8` | Maximum number of frames passed to ROBOMETER |
|
||||
|
||||
The model predicts per-frame progress and success internally. The LeRobot reward API returns a scalar per sample:
|
||||
|
||||
- `reward_output="progress"` (default): return the last-frame progress, clamped to `[0, 1]`.
|
||||
- `reward_output="success"`: return `1.0` if the last-frame success probability is above `success_threshold`, otherwise `0.0`.
|
||||
|
||||
## Usage
|
||||
|
||||
### Load the Reward Model Directly
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||
|
||||
cfg = RobometerConfig(
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
reward_output="progress",
|
||||
)
|
||||
reward_model = RobometerRewardModel.from_pretrained(cfg.pretrained_path, config=cfg)
|
||||
```
|
||||
|
||||
### Encode Frames and Compute a Reward
|
||||
|
||||
For a direct Python call, provide frames as `uint8` arrays with shape `(T, H, W, C)` and a task string:
|
||||
|
||||
```python
|
||||
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
|
||||
# frames: np.ndarray, shape (T, H, W, C), dtype uint8
|
||||
# task: str
|
||||
encoder = RobometerEncoderProcessorStep(
|
||||
base_model_id=cfg.base_model_id,
|
||||
use_multi_image=cfg.use_multi_image,
|
||||
use_per_frame_progress_token=cfg.use_per_frame_progress_token,
|
||||
max_frames=cfg.max_frames,
|
||||
)
|
||||
|
||||
encoded = encoder.encode_samples([(frames, task)])
|
||||
batch = {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in encoded.items()}
|
||||
|
||||
reward = reward_model.compute_reward(batch)
|
||||
```
|
||||
|
||||
`reward` is a tensor of shape `(batch_size,)`.
|
||||
|
||||
### Use the Reward Factory
|
||||
|
||||
You can also instantiate ROBOMETER through the reward factory:
|
||||
|
||||
```python
|
||||
from lerobot.rewards import make_reward_model, make_reward_model_config, make_reward_pre_post_processors
|
||||
|
||||
cfg = make_reward_model_config(
|
||||
"robometer",
|
||||
pretrained_path="lerobot/Robometer-4B",
|
||||
device="cuda",
|
||||
image_key="observation.images.top",
|
||||
)
|
||||
reward_model = make_reward_model(cfg)
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(cfg)
|
||||
```
|
||||
|
||||
The preprocessor writes Qwen-VL tensors under the `observation.robometer.*` namespace, and `compute_reward()` reads those encoded tensors.
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Backbone and Vocabulary
|
||||
|
||||
The published checkpoint uses a Qwen3-VL-4B backbone. ROBOMETER adds five special tokens to the tokenizer in a fixed order:
|
||||
|
||||
```text
|
||||
<|split_token|>
|
||||
<|reward_token|>
|
||||
<|pref_token|>
|
||||
<|sim_token|>
|
||||
<|prog_token|>
|
||||
```
|
||||
|
||||
`<|prog_token|>` is inserted after each frame and is the hidden-state position used for per-frame progress and success prediction. `<|split_token|>` and `<|pref_token|>` are used by the paper's pairwise trajectory preference objective. `<|reward_token|>` and `<|sim_token|>` are preserved for checkpoint compatibility.
|
||||
|
||||
The LeRobot config stores a serialized `vlm_config` with the post-resize vocabulary so the model can reload from `config.json` without downloading the base Qwen weights first. For `Qwen/Qwen3-VL-4B-Instruct`, the tokenizer length is `151669`, and the five ROBOMETER tokens produce the checkpoint vocabulary size `151674`.
|
||||
|
||||
### Progress Prediction
|
||||
|
||||
In the published checkpoint, progress is discrete. The progress head outputs logits over `progress_discrete_bins=10` uniformly spaced bin centers in `[0, 1]`. LeRobot converts these logits into a continuous value by applying a softmax and taking the expectation over bin centers, matching the upstream ROBOMETER implementation.
|
||||
|
||||
### Success Prediction
|
||||
|
||||
The success head outputs raw logits per frame. LeRobot converts them to probabilities with `sigmoid`. When `reward_output="success"`, `compute_reward()` thresholds the last-frame success probability using `success_threshold`.
|
||||
|
||||
## Limitations
|
||||
|
||||
- The current LeRobot integration is inference-only; it does not implement ROBOMETER training or preference-pair training.
|
||||
- `compute_reward()` returns a scalar per sample for the LeRobot reward-model API, even though ROBOMETER predicts per-frame progress and success internally.
|
||||
- ROBOMETER is video-language based; it does not use privileged robot state such as contact forces or object poses.
|
||||
|
||||
## References
|
||||
|
||||
- [ROBOMETER project](https://robometer.github.io/)
|
||||
- [ROBOMETER paper](https://arxiv.org/abs/2603.02115)
|
||||
- [Original ROBOMETER code](https://github.com/robometer/robometer)
|
||||
- [Published ROBOMETER-4B checkpoint](https://huggingface.co/lerobot/Robometer-4B)
|
||||
- [Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{liang2026robometer,
|
||||
title = {Robometer: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons},
|
||||
author={Anthony Liang and Yigit Korkmaz and Jiahui Zhang and Minyoung Hwang and Abrar Anwar and Sidhant Kaushik and Aditya Shah and Alex S. Huang and Luke Zettlemoyer and Dieter Fox and Yu Xiang and Anqi Li and Andreea Bobu and Abhishek Gupta and Stephen Tu and Erdem Biyik and Jesse Zhang},
|
||||
year={2026},
|
||||
booktitle={Robotics: Science and Systems 2026},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream ROBOMETER code and model pages for the licenses of the original implementation and released checkpoints.
|
||||
@@ -1,235 +0,0 @@
|
||||
# VLA-JEPA
|
||||
|
||||
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
VLA-JEPA has three main components:
|
||||
|
||||
| Component | Module | Role |
|
||||
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||
|
||||
### Data flow
|
||||
|
||||
**Training:**
|
||||
|
||||
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||
|
||||
**Inference:**
|
||||
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||
|
||||
### Action head details
|
||||
|
||||
Available presets via `action_model_type`:
|
||||
|
||||
| Preset | Hidden dim | Heads | Head dim |
|
||||
| ------- | ---------- | ----- | -------- |
|
||||
| `DiT-B` | 768 | 12 | 64 |
|
||||
| `DiT-L` | 1536 | 32 | 48 |
|
||||
|
||||
### World model details
|
||||
|
||||
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||
|
||||
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||
|
||||
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||
|
||||
---
|
||||
|
||||
## Pretrained Checkpoints
|
||||
|
||||
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||
|
||||
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
|
||||
|
||||
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
Key parameters in `VLAJEPAConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
|
||||
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
|
||||
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
|
||||
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
|
||||
|
||||
---
|
||||
|
||||
## Training
|
||||
|
||||
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||
|
||||
### Full training from scratch
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
policy.type=vla_jepa \
|
||||
policy.repo_id=your_org/your_repo \
|
||||
dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning from a pretrained checkpoint
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
### Fine-tuning on a different embodiment
|
||||
|
||||
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||
|
||||
The layers that depend on `action_dim` and `state_dim` are:
|
||||
|
||||
| Layer | Key prefix |
|
||||
| ----------------------------------------- | ----------------------------------- |
|
||||
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--policy.freeze_qwen=true \
|
||||
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||
--dataset.repo_id=your_org/your_dataset
|
||||
```
|
||||
|
||||
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||
|
||||
### Reproducing the LIBERO results
|
||||
|
||||
**Training on LIBERO:**
|
||||
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--steps=30000
|
||||
```
|
||||
|
||||
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
To evaluate a subset of tasks only:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.task_ids='[0,1,2]' \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.batch_size=5
|
||||
```
|
||||
|
||||
**Expected results:**
|
||||
|
||||
| Suite | Episodes | Successes | Success Rate |
|
||||
| -------------- | -------- | --------- | ------------ |
|
||||
| libero_spatial | 100 | 93 | **95.0%** |
|
||||
| libero_object | 100 | 100 | **100.0%** |
|
||||
| libero_goal | 100 | 98 | **98.0%** |
|
||||
| libero_10 | 100 | 96 | **93.0%** |
|
||||
| **Overall** | **400** | **387** | **96.5%** |
|
||||
|
||||
---
|
||||
|
||||
## Fine-tuning on datasets with a different number of cameras
|
||||
|
||||
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
|
||||
|
||||
**Default behaviour — view padding / trimming (no action required)**
|
||||
|
||||
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
|
||||
|
||||
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
|
||||
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
|
||||
|
||||
**Option 1 — Disable the world model**
|
||||
|
||||
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||
--policy.enable_world_model=false \
|
||||
--policy.repo_id=your_org/your_repo \
|
||||
--dataset.repo_id=your_org/single_camera_dataset
|
||||
```
|
||||
|
||||
**Option 2 — Reinitialize the predictor input projection**
|
||||
|
||||
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||
year = {2026},
|
||||
eprint = {2602.10098},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2602.10098},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||
@@ -212,12 +212,10 @@ groot = [
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
@@ -283,7 +281,6 @@ all = [
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
@@ -294,7 +291,6 @@ all = [
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
|
||||
@@ -41,8 +41,8 @@ class DatasetRecordConfig:
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# If True, upload as private; if None, defer to the org default on the Hub (only affects orgs).
|
||||
private: bool | None = None
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
|
||||
@@ -177,12 +177,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if self.rename_map and active_cfg.pretrained_path is None:
|
||||
raise ValueError(
|
||||
"`rename_map` requires a pretrained policy checkpoint. "
|
||||
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
|
||||
@@ -524,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool | None = None,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
**card_kwargs,
|
||||
@@ -543,8 +543,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
tag_version: If ``True``, create a Git tag for the current codebase
|
||||
version.
|
||||
push_videos: If ``False``, skip uploading the ``videos/`` directory.
|
||||
private: If ``True``, create a private repository. If ``None``
|
||||
(default), defer to the org default on the Hub (only affects orgs).
|
||||
private: If ``True``, create a private repository.
|
||||
allow_patterns: Glob pattern(s) restricting which files to upload.
|
||||
upload_large_folder: If ``True``, use ``upload_large_folder`` instead
|
||||
of ``upload_folder`` for very large datasets.
|
||||
|
||||
@@ -57,7 +57,6 @@ from .pretrained import PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
@@ -158,10 +157,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
|
||||
return MolmoAct2Policy
|
||||
elif name == "vla_jepa":
|
||||
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
|
||||
return VLAJEPAPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -216,8 +211,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "molmoact2":
|
||||
return MolmoAct2Config(**kwargs)
|
||||
elif policy_type == "vla_jepa":
|
||||
return VLAJEPAConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -280,22 +273,26 @@ def make_pre_post_processors(
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
|
||||
# GROOT handles normalization in groot_pack_inputs_v3 step
|
||||
# Need to override both stats AND normalize_min_max since saved config might be empty
|
||||
preprocessor_overrides = {}
|
||||
postprocessor_overrides = {}
|
||||
preprocessor_overrides["groot_pack_inputs_v3"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
}
|
||||
|
||||
return make_groot_pre_post_processors_from_pretrained(
|
||||
config=policy_cfg,
|
||||
pretrained_path=pretrained_path,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
|
||||
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
|
||||
preprocessor_config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
postprocessor_config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
)
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
"env_action_dim": env_action_dim,
|
||||
}
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -418,7 +415,6 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
@@ -436,14 +432,6 @@ def make_pre_post_processors(
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
processors = make_vla_jepa_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_policy_config(
|
||||
|
||||
@@ -18,12 +18,4 @@ from .configuration_groot import GrootConfig
|
||||
from .modeling_groot import GrootPolicy
|
||||
from .processor_groot import make_groot_pre_post_processors
|
||||
|
||||
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in {"GR00TN17", "GR00TN17Config"}:
|
||||
from .groot_n1_7 import GR00TN17, GR00TN17Config
|
||||
|
||||
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Produces a sinusoidal encoding of shape (B, T, w)
|
||||
given timesteps of shape (B, T).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
# timesteps: shape (B, T)
|
||||
# We'll compute sin/cos frequencies across dim T
|
||||
timesteps = timesteps.float() # ensure float
|
||||
|
||||
b, t = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embedding_dim // 2
|
||||
# typical log space frequencies for sinusoidal encoding
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
# Expand timesteps to (B, T, 1) then multiply
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -43,9 +42,6 @@ else:
|
||||
Timesteps = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
require_package("diffusers", extra="groot")
|
||||
@@ -185,7 +181,8 @@ class BasicTransformerBlock(nn.Module):
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
@@ -269,8 +266,8 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
logger.debug(
|
||||
"Total number of DiT parameters: %d",
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
@@ -321,71 +318,6 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class AlternateVLDiT(DiT):
|
||||
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
|
||||
|
||||
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attend_text_every_n_blocks = attend_text_every_n_blocks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
image_mask: torch.Tensor | None = None,
|
||||
backbone_attention_mask: torch.Tensor | None = None,
|
||||
):
|
||||
if image_mask is None:
|
||||
raise ValueError("image_mask is required for AlternateVLDiT.")
|
||||
if backbone_attention_mask is None:
|
||||
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
|
||||
|
||||
temb = self.timestep_encoder(timestep)
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
image_attention_mask = image_mask & backbone_attention_mask
|
||||
non_image_attention_mask = (~image_mask) & backbone_attention_mask
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
if not self.config.interleave_self_attention:
|
||||
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
|
||||
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
curr_encoder_attention_mask = (
|
||||
non_image_attention_mask
|
||||
if idx % (2 * self.attend_text_every_n_blocks) == 0
|
||||
else image_attention_mask
|
||||
)
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=curr_encoder_attention_mask,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@@ -430,8 +362,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
logger.debug(
|
||||
"Total number of SelfAttentionTransformer parameters: %d",
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = None
|
||||
|
||||
from .action_encoder import (
|
||||
SinusoidalPositionalEncoding,
|
||||
swish,
|
||||
)
|
||||
from .cross_attention_dit import DiT, SelfAttentionTransformer
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
b, t, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == b:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, t)
|
||||
else:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
|
||||
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
|
||||
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
|
||||
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
|
||||
backbone_embedding_dim: int = field(
|
||||
default=1536, metadata={"help": "Backbone embedding channel dimension."}
|
||||
)
|
||||
|
||||
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
|
||||
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
|
||||
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
|
||||
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
|
||||
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
|
||||
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
|
||||
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
|
||||
num_timestep_buckets: int = field(
|
||||
default=1000, metadata={"help": "Number of timestep discretization buckets."}
|
||||
)
|
||||
num_inference_timesteps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of inference steps for noise diffusion."},
|
||||
)
|
||||
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
|
||||
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
|
||||
tune_diffusion_model: bool = field(
|
||||
default=True, metadata={"help": "Whether to tune the diffusion model."}
|
||||
)
|
||||
load_pretrained_det_decode_layer_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained detection model."}
|
||||
)
|
||||
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
|
||||
|
||||
freeze_decode_layer: bool = field(default=False)
|
||||
expand_batch: int = field(default=None)
|
||||
use_vlln: bool = field(default=True)
|
||||
|
||||
vl_self_attention_cfg: dict = field(default=None)
|
||||
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
config_class = FlowmatchingActionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlowmatchingActionHeadConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
self.model = DiT(**config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
|
||||
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
self.vl_self_attention = (
|
||||
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
|
||||
)
|
||||
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
|
||||
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
print(f"Tune action head projector: {self.tune_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_projector and not tune_diffusion_model:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Action head trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
if self._beta_dist is None:
|
||||
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = backbone_output["backbone_features"]
|
||||
backbone_features = self.vlln(backbone_features)
|
||||
backbone_features = self.vl_self_attention(backbone_features)
|
||||
backbone_output["backbone_features"] = backbone_features
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
# Set frozen modules to eval
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
if self.config.expand_batch is not None:
|
||||
for k, v in backbone_output.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
backbone_output[k] = expanded
|
||||
|
||||
for k, v in action_input.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
action_input[k] = expanded
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
device = vl_embs.device
|
||||
|
||||
# Get embodiment ID.
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
vl_attn_mask = backbone_output.backbone_attention_mask
|
||||
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=vl_attn_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=False, # NOTE (YL): not using flare now
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Slice out only the action portion of pred and target.
|
||||
action_mask = action_input.action_mask
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = loss.sum() / action_mask.sum()
|
||||
output_dict = {
|
||||
"loss": loss,
|
||||
}
|
||||
return BatchFeature(data=output_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Set initial actions as the sampled noise.
|
||||
batch_size = vl_embs.shape[0]
|
||||
device = vl_embs.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=vl_embs.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# Run denoising steps.
|
||||
for t in range(num_steps):
|
||||
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
# Run model forward.
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
|
||||
pred_velocity = pred[:, -self.action_horizon :]
|
||||
|
||||
# Update actions using euler integration.
|
||||
actions = actions + dt * pred_velocity
|
||||
return BatchFeature(data={"action_pred": actions})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
@@ -14,327 +14,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||
GROOT_N1_5 = "n1.5"
|
||||
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
"GR00T N1.5 support was removed from LeRobot. "
|
||||
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
|
||||
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||
|
||||
_GROOT_MODEL_VERSION_ALIASES = {
|
||||
"n1.7": GROOT_N1_7,
|
||||
"n1_7": GROOT_N1_7,
|
||||
"n1d7": GROOT_N1_7,
|
||||
"n17": GROOT_N1_7,
|
||||
"1.7": GROOT_N1_7,
|
||||
}
|
||||
|
||||
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||
|
||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||
"none": None,
|
||||
"": None,
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
}
|
||||
|
||||
|
||||
def normalize_groot_model_version(model_version: str) -> str:
|
||||
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
|
||||
if normalized is None:
|
||||
supported = GROOT_N1_7
|
||||
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
|
||||
if transform is None:
|
||||
return None
|
||||
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
|
||||
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
|
||||
supported = ", ".join(
|
||||
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
|
||||
f"Supported transforms: none, {supported}."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def infer_groot_model_version(model_path: str | None) -> str | None:
|
||||
if not model_path:
|
||||
return None
|
||||
model_path_lower = model_path.lower()
|
||||
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
|
||||
return GROOT_N1_7
|
||||
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
|
||||
# N1.5 is unsupported, but it must still be recognised here to fail loudly
|
||||
# rather than silently treating an N1.5 checkpoint as N1.7.
|
||||
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
|
||||
return GROOT_N1_5
|
||||
config_version = _infer_groot_model_version_from_local_config(model_path)
|
||||
if config_version is not None:
|
||||
return config_version
|
||||
return None
|
||||
|
||||
|
||||
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
|
||||
if model_path is None:
|
||||
return False
|
||||
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
|
||||
|
||||
|
||||
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
if "libero_sim" in modality_configs:
|
||||
return "libero_sim"
|
||||
if len(modality_configs) == 1:
|
||||
return next(iter(modality_configs))
|
||||
return None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
return None
|
||||
modality_configs = processor_kwargs.get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag is None:
|
||||
return None
|
||||
|
||||
embodiment_config = modality_configs.get(embodiment_tag, {})
|
||||
if not isinstance(embodiment_config, dict):
|
||||
return None
|
||||
action_config = embodiment_config.get("action", {})
|
||||
if not isinstance(action_config, dict):
|
||||
return None
|
||||
delta_indices = action_config.get("delta_indices", [])
|
||||
if not isinstance(delta_indices, list):
|
||||
return None
|
||||
return len(delta_indices) or None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_execution_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
|
||||
if action_horizon is None:
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag == "libero_sim":
|
||||
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
|
||||
# actions. Keeping that execution cadence avoids stale open-loop chunks.
|
||||
return min(action_horizon, 8)
|
||||
return action_horizon
|
||||
|
||||
|
||||
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
|
||||
model_path = Path(model_name).expanduser()
|
||||
if model_path.exists():
|
||||
return str(model_path)
|
||||
|
||||
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
|
||||
return str(cached_snapshot) if cached_snapshot is not None else model_name
|
||||
|
||||
|
||||
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
|
||||
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
|
||||
required_files = (
|
||||
"config.json",
|
||||
"tokenizer_config.json",
|
||||
"preprocessor_config.json",
|
||||
"video_preprocessor_config.json",
|
||||
)
|
||||
|
||||
for hub_cache in _candidate_hf_hub_caches(cache_dir):
|
||||
repo_cache = hub_cache / repo_cache_name
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
if not snapshots_dir.is_dir():
|
||||
continue
|
||||
|
||||
candidates: list[Path] = []
|
||||
ref_path = repo_cache / "refs" / "main"
|
||||
try:
|
||||
ref = ref_path.read_text().strip()
|
||||
except OSError:
|
||||
ref = ""
|
||||
if ref:
|
||||
candidates.append(snapshots_dir / ref)
|
||||
candidates.extend(
|
||||
sorted(
|
||||
(path for path in snapshots_dir.iterdir() if path.is_dir()),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[Path] = set()
|
||||
for snapshot in candidates:
|
||||
if snapshot in seen:
|
||||
continue
|
||||
seen.add(snapshot)
|
||||
if all((snapshot / filename).exists() for filename in required_files):
|
||||
return snapshot
|
||||
return None
|
||||
|
||||
|
||||
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
|
||||
candidates: list[Path] = []
|
||||
if cache_dir is not None:
|
||||
cache_path = Path(cache_dir).expanduser()
|
||||
candidates.append(cache_path)
|
||||
candidates.append(cache_path / "hub")
|
||||
|
||||
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
|
||||
if hub_cache:
|
||||
candidates.append(Path(hub_cache).expanduser())
|
||||
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if hf_home:
|
||||
candidates.append(Path(hf_home).expanduser() / "hub")
|
||||
|
||||
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
|
||||
|
||||
deduped: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for candidate in candidates:
|
||||
resolved = candidate.resolve() if candidate.exists() else candidate
|
||||
if resolved not in seen:
|
||||
seen.add(resolved)
|
||||
deduped.append(candidate)
|
||||
return deduped
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return None
|
||||
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return _infer_groot_model_version_from_config(config)
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
model_version = config.get("model_version")
|
||||
if isinstance(model_version, str):
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
return GROOT_N1_5
|
||||
try:
|
||||
return normalize_groot_model_version(model_version)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
|
||||
for candidate in candidates:
|
||||
if not isinstance(candidate, str):
|
||||
continue
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
return GROOT_N1_7
|
||||
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||
backbone_cfg = config.get("backbone_cfg")
|
||||
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||
return GROOT_N1_5
|
||||
return None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@dataclass
|
||||
@@ -343,44 +28,35 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 40
|
||||
n_action_steps: int = 40
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 132
|
||||
max_state_dim: int = 64
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 132
|
||||
max_action_dim: int = 32
|
||||
|
||||
# GR00T normalizes state/action internally in its processor steps (min/max with
|
||||
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
|
||||
# handles image normalization. The policy therefore does NOT use LeRobot's
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
|
||||
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Groot-specific model parameters
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
base_model_path: str | None = None
|
||||
base_model_path: str = "nvidia/GR00T-N1.5-3B"
|
||||
|
||||
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
|
||||
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
@@ -420,16 +96,17 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
@@ -440,66 +117,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
self.base_model_path = GROOT_N1_7_BASE_MODEL
|
||||
|
||||
# The N1.7 LIBERO checkpoints emit a [0, 1] gripper action, but the LIBERO
|
||||
# simulator expects the OpenVLA/[-1, 1] sign convention. NVIDIA's rollout
|
||||
# wrapper applies this conversion; mirror it here so eval on the
|
||||
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
||||
# This matches the embodiment-specific handling already done for the
|
||||
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
||||
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
|
||||
# 'none' (normalized to None above) keeps the transform disabled.
|
||||
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
|
||||
self.action_decode_transform = (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
|
||||
)
|
||||
|
||||
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
|
||||
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
|
||||
# warning. The dataclass defaults are already the N1.7 values, so a plain
|
||||
# GrootConfig() never triggers this.
|
||||
legacy_default_remaps = (
|
||||
("max_state_dim", 64, 132),
|
||||
("max_action_dim", 32, 132),
|
||||
("chunk_size", 50, 40),
|
||||
("n_action_steps", 50, 40),
|
||||
("image_size", (224, 224), (256, 256)),
|
||||
)
|
||||
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
|
||||
current_value = getattr(self, field_name)
|
||||
if isinstance(legacy_value, tuple):
|
||||
current_value = tuple(current_value)
|
||||
if current_value == legacy_value:
|
||||
logger.warning(
|
||||
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
|
||||
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
|
||||
"if this is not what you want.",
|
||||
field_name,
|
||||
legacy_value,
|
||||
n1_7_value,
|
||||
)
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
@@ -575,10 +192,7 @@ class GrootConfig(PreTrainedConfig):
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
return list(range(min(self.chunk_size, 16)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
|
||||
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Eagle25VLConfig(PretrainedConfig):
|
||||
model_type = "eagle_2_5_vl"
|
||||
is_composition = True
|
||||
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
pad2square=False,
|
||||
select_layer=-4,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
loss_version="v1",
|
||||
min_dynamic_tiles=1,
|
||||
max_dynamic_tiles=6,
|
||||
mlp_checkpoint=False,
|
||||
initializer_range=0.02,
|
||||
_attn_implementation="flash_attention_2",
|
||||
_attn_implementation_autoset=False,
|
||||
llm_config=None,
|
||||
image_token_index=None,
|
||||
use_pixel_shuffle=True,
|
||||
mlp_connector_layers=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {"model_type": "siglip_vision_model"}
|
||||
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {"architectures": ["Qwen2ForCausalLM"]}
|
||||
logger.info(
|
||||
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
|
||||
)
|
||||
|
||||
if vision_config["model_type"] == "siglip_vision_model":
|
||||
self.vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
|
||||
|
||||
if text_config["architectures"][0] == "LlamaForCausalLM":
|
||||
self.text_config = LlamaConfig(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
|
||||
self.text_config = Qwen2Config(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
|
||||
self.text_config = Qwen3Config(**text_config)
|
||||
else:
|
||||
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.mlp_checkpoint = mlp_checkpoint
|
||||
self.pad2square = pad2square
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.loss_version = loss_version
|
||||
self.initializer_range = initializer_range
|
||||
self.min_dynamic_tiles = min_dynamic_tiles
|
||||
self.max_dynamic_tiles = max_dynamic_tiles
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self._attn_implementation = _attn_implementation
|
||||
self._attn_implementation_autoset = _attn_implementation_autoset
|
||||
self.image_token_index = image_token_index
|
||||
self.use_pixel_shuffle = use_pixel_shuffle
|
||||
self.mlp_connector_layers = mlp_connector_layers
|
||||
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
|
||||
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
output["use_backbone_lora"] = self.use_backbone_lora
|
||||
output["use_llm_lora"] = self.use_llm_lora
|
||||
output["pad2square"] = self.pad2square
|
||||
output["select_layer"] = self.select_layer
|
||||
output["force_image_size"] = self.force_image_size
|
||||
output["downsample_ratio"] = self.downsample_ratio
|
||||
output["template"] = self.template
|
||||
output["dynamic_image_size"] = self.dynamic_image_size
|
||||
output["use_thumbnail"] = self.use_thumbnail
|
||||
output["min_dynamic_tiles"] = self.min_dynamic_tiles
|
||||
output["max_dynamic_tiles"] = self.max_dynamic_tiles
|
||||
output["tie_word_embeddings"] = self.tie_word_embeddings
|
||||
output["_attn_implementation"] = self._attn_implementation
|
||||
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
|
||||
output["use_pixel_shuffle"] = self.use_pixel_shuffle
|
||||
output["mlp_connector_layers"] = self.mlp_connector_layers
|
||||
return output
|
||||
@@ -0,0 +1,503 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
ImagesKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
|
||||
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_flat_list_of_images,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from transformers.image_utils import pil_torch_interpolation_mapping
|
||||
else:
|
||||
from torchvision.transforms import functional as F # noqa: N812
|
||||
|
||||
|
||||
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
|
||||
"""Crop the given numpy array.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
|
||||
left (int): The left coordinate of the crop box.
|
||||
top (int): The top coordinate of the crop box.
|
||||
right (int): The right coordinate of the crop box.
|
||||
bottom (int): The bottom coordinate of the crop box.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Cropped image.
|
||||
"""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
|
||||
|
||||
if img.ndim not in [2, 3]:
|
||||
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
|
||||
|
||||
img_height = img.shape[1]
|
||||
img_width = img.shape[2]
|
||||
if top < 0 or left < 0 or bottom > img_height or right > img_width:
|
||||
raise ValueError("Crop coordinates out of bounds")
|
||||
|
||||
if top >= bottom or left >= right:
|
||||
raise ValueError("Invalid crop coordinates")
|
||||
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
pad_during_tiling: bool | None
|
||||
do_pad: bool | None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
image_grid_pinpoints (`List[List[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method. Not used for processing videos.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 448, "width": 448}
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_pad = True
|
||||
max_dynamic_tiles = 12
|
||||
min_dynamic_tiles = 1
|
||||
use_thumbnail = True
|
||||
pad_during_tiling = False
|
||||
valid_kwargs = Eagle25VLFastImageProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
max_dynamic_tiles (`int`, *optional*):
|
||||
The maximum number of dynamic tiles to use for processing high resolution images.
|
||||
min_dynamic_tiles (`int`, *optional*):
|
||||
The minimum number of dynamic tiles to use for processing high resolution images.
|
||||
use_thumbnail (`bool`, *optional*):
|
||||
Whether to use a thumbnail for processing high resolution images.
|
||||
pad_during_tiling (`bool`, *optional*):
|
||||
Whether to pad the image during tiling.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
|
||||
# NOTE(YL): we will overload the preprocess method to add the image_flags
|
||||
# def preprocess(
|
||||
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
|
||||
# ) -> BatchFeature:
|
||||
# return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
expected_ndims: int = 3,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
expected_ndims (`int`, *optional*, defaults to 3):
|
||||
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_flat_list_of_images(images)
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_resolution: tuple,
|
||||
interpolation: F.InterpolationMode,
|
||||
input_data_format: ChannelDimension,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image ("torch.Tensor"):
|
||||
The input image.
|
||||
target_resolution (tuple):
|
||||
The target resolution (height, width) of the image.
|
||||
interpolation (`InterpolationMode`):
|
||||
Resampling filter to use if resizing the image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
"torch.Tensor": The resized and padded image.
|
||||
"""
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
# Resize the image
|
||||
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
|
||||
|
||||
return resized_image
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
"""
|
||||
previous version mainly focus on ratio.
|
||||
We also consider area ratio here.
|
||||
"""
|
||||
best_factor = float("-inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
|
||||
"""
|
||||
new area > 60% of original image area is enough.
|
||||
"""
|
||||
factor_based_on_area_n_ratio = min(
|
||||
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
|
||||
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
|
||||
|
||||
if factor_based_on_area_n_ratio > best_factor:
|
||||
best_factor = factor_based_on_area_n_ratio
|
||||
best_ratio = ratio
|
||||
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
|
||||
|
||||
return padded_image
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: F.InterpolationMode,
|
||||
pad_during_tiling: bool,
|
||||
) -> list[torch.Tensor]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = tile_size * target_aspect_ratio[0]
|
||||
target_height = tile_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if pad_during_tiling:
|
||||
resized_image = self._resize_for_patching(
|
||||
image,
|
||||
(target_height, target_width),
|
||||
interpolation=interpolation,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
padded_image = self._pad_for_patching(
|
||||
resized_image,
|
||||
(target_height, target_width),
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
image_used_to_split = padded_image
|
||||
else:
|
||||
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
|
||||
|
||||
processed_tiles = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // tile_size)) * tile_size,
|
||||
(i // (target_width // tile_size)) * tile_size,
|
||||
((i % (target_width // tile_size)) + 1) * tile_size,
|
||||
((i // (target_width // tile_size)) + 1) * tile_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
|
||||
processed_tiles.append(split_img)
|
||||
assert len(processed_tiles) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_tiles) != 1:
|
||||
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
|
||||
processed_tiles.append(thumbnail_img)
|
||||
|
||||
return processed_tiles
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list[torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[torch.Tensor]`):
|
||||
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
|
||||
Returns:
|
||||
List[`torch.Tensor`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: F.InterpolationMode | None,
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
do_pad: bool,
|
||||
return_tensors: str | TensorType | None,
|
||||
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
else:
|
||||
size_tuple = (size.shortest_edge, size.shortest_edge)
|
||||
|
||||
# Determine the patch size
|
||||
if crop_size and crop_size.height:
|
||||
tile_size = crop_size.height
|
||||
elif size and size.height:
|
||||
tile_size = size.height
|
||||
else:
|
||||
tile_size = size.shortest_edge
|
||||
|
||||
for image in images:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
min_num=min_dynamic_tiles,
|
||||
max_num=max_dynamic_tiles,
|
||||
size=size_tuple,
|
||||
tile_size=tile_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
interpolation=interpolation,
|
||||
pad_during_tiling=pad_during_tiling,
|
||||
)
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
|
||||
image_patches,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
if do_center_crop:
|
||||
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
image_mean,
|
||||
image_std,
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(
|
||||
processed_image_patches_grouped, grouped_image_patches_index
|
||||
)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.append(processed_image_patches)
|
||||
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
|
||||
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
|
||||
)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
|
||||
if images is not None:
|
||||
images = self._prepare_image_like_inputs(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if videos is not None:
|
||||
videos = self._prepare_image_like_inputs(
|
||||
images=videos,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
resample = kwargs.pop("resample", self.resample)
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample]
|
||||
if isinstance(resample, PILImageResampling | int)
|
||||
else resample
|
||||
)
|
||||
|
||||
# Filter kwargs to only include those accepted by _preprocess
|
||||
valid_preprocess_kwargs = {
|
||||
"do_resize",
|
||||
"size",
|
||||
"max_dynamic_tiles",
|
||||
"min_dynamic_tiles",
|
||||
"use_thumbnail",
|
||||
"pad_during_tiling",
|
||||
"interpolation",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_pad",
|
||||
"return_tensors",
|
||||
"pad_size",
|
||||
"disable_grouping",
|
||||
}
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
|
||||
if images is not None:
|
||||
return self._preprocess(images, **filtered_kwargs)
|
||||
elif videos is not None:
|
||||
return self._preprocess(videos, **filtered_kwargs)
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLImageProcessorFast"]
|
||||
@@ -0,0 +1,396 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as cp
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
||||
from transformers.utils import add_start_docstrings, logging
|
||||
|
||||
from .configuration_eagle2_5_vl import Eagle25VLConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
|
||||
EAGLE2_5_VL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Eagle25VLConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
|
||||
EAGLE2_5_VL_START_DOCSTRING,
|
||||
)
|
||||
class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
config_class = Eagle25VLConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Qwen2DecoderLayer",
|
||||
"LlamaDecoderLayer",
|
||||
"Siglip2EncoderLayer",
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear | nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
|
||||
config_class = Eagle25VLConfig
|
||||
|
||||
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
|
||||
super().__init__(config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
if config.use_pixel_shuffle:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
|
||||
else:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2)
|
||||
|
||||
self.select_layer = config.select_layer
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.loss_version = config.loss_version
|
||||
self.mlp_checkpoint = config.mlp_checkpoint
|
||||
self.use_pixel_shuffle = config.use_pixel_shuffle
|
||||
self.mlp_connector_layers = config.mlp_connector_layers
|
||||
logger.info(f"num_image_token: {self.num_image_token}")
|
||||
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
|
||||
if vision_model is not None:
|
||||
self.vision_model = vision_model
|
||||
else:
|
||||
if config.vision_config.model_type == "siglip_vision_model":
|
||||
config.vision_config._attn_implementation = "flash_attention_2"
|
||||
self.vision_model = SiglipVisionModel(config.vision_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
|
||||
|
||||
if language_model is not None:
|
||||
self.language_model = language_model
|
||||
else:
|
||||
if config.text_config.architectures[0] == "LlamaForCausalLM":
|
||||
self.language_model = LlamaForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
|
||||
raise NotImplementedError("Phi3 is not implemented.")
|
||||
# self.language_model = Phi3ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
|
||||
assert config.text_config._attn_implementation == "flash_attention_2", (
|
||||
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
|
||||
)
|
||||
self.language_model = Qwen2ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
|
||||
self.language_model = Qwen3ForCausalLM(config.text_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
if config.mlp_connector_layers == 2:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size, llm_hidden_size),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
|
||||
|
||||
self.image_token_index = config.image_token_index
|
||||
self.neftune_alpha = None
|
||||
|
||||
if config.use_backbone_lora:
|
||||
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
|
||||
|
||||
self.use_llm_lora = config.use_llm_lora
|
||||
if config.use_llm_lora:
|
||||
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
|
||||
|
||||
self.check_forward_kwargs()
|
||||
|
||||
def check_forward_kwargs(self):
|
||||
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
|
||||
# has special handling for functions with **kwargs parameters that would affect
|
||||
# how our model is processed during training and inference.
|
||||
forward_params = inspect.signature(self.forward).parameters
|
||||
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
|
||||
|
||||
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.out_proj",
|
||||
"mlp.fc1",
|
||||
"mlp.fc2",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
self.vision_model = get_peft_model(self.vision_model, lora_config)
|
||||
self.vision_model.print_trainable_parameters()
|
||||
|
||||
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.o_proj",
|
||||
"mlp.gate_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.up_proj",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.language_model = get_peft_model(self.language_model, lora_config)
|
||||
self.language_model.enable_input_require_grads()
|
||||
self.language_model.print_trainable_parameters()
|
||||
self.use_llm_lora = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
image_flags: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_tiles_list: list[torch.Tensor] | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
if image_flags is not None:
|
||||
image_flags = image_flags.view(-1)
|
||||
vit_embeds = vit_embeds[image_flags == 1]
|
||||
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.image_token_index
|
||||
try:
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
|
||||
except Exception as e:
|
||||
vit_embeds = vit_embeds.reshape(-1, c)
|
||||
print(
|
||||
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
|
||||
f"vit_embeds.shape={vit_embeds.shape}"
|
||||
)
|
||||
n_token = selected.sum()
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
|
||||
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
if self.select_layer == -1:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
|
||||
)
|
||||
if hasattr(vit_embeds, "last_hidden_state"):
|
||||
vit_embeds = vit_embeds.last_hidden_state
|
||||
|
||||
else:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
|
||||
).hidden_states[self.select_layer]
|
||||
|
||||
if self.use_pixel_shuffle:
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
|
||||
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
|
||||
|
||||
if self.mlp_checkpoint and vit_embeds.requires_grad:
|
||||
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
|
||||
else:
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
|
||||
return vit_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
input_ids: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
visual_features: torch.FloatTensor | None = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
image_sizes: list[tuple[int, int]] | None = None,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
if pixel_values is not None:
|
||||
if visual_features is not None:
|
||||
vit_embeds = visual_features
|
||||
else:
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.config.image_token_index
|
||||
assert selected.sum() != 0
|
||||
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
else:
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if "use_cache" not in generate_kwargs:
|
||||
generate_kwargs["use_cache"] = True
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
@@ -0,0 +1,541 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Eagle25VL.
|
||||
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
FRAME_FACTOR = 2
|
||||
FPS = 2.0
|
||||
FPS_MIN_FRAMES = 4
|
||||
FPS_MAX_FRAMES = 256
|
||||
|
||||
|
||||
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||
if pil_image.mode == "RGBA":
|
||||
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||
return white_background
|
||||
else:
|
||||
return pil_image.convert("RGB")
|
||||
|
||||
|
||||
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
|
||||
image = ele["image"] if "image" in ele else ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(image, stream=True, timeout=10)
|
||||
image_obj = Image.open(BytesIO(response.content))
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = to_rgb(image_obj)
|
||||
if "scale_factor" in ele:
|
||||
scale_factor = ele["scale_factor"]
|
||||
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
|
||||
return image
|
||||
|
||||
|
||||
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
# see processing_utils.ProcessingKwargs documentation for usage.
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
"videos_kwargs": {"max_dynamic_tiles": 1},
|
||||
}
|
||||
|
||||
|
||||
class Eagle25VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
|
||||
|
||||
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
|
||||
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
num_image_tokens (`int`, *optional*):
|
||||
Number of image tokens for one imagethat will be returned by vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be same as in model's config
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"num_image_tokens",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"video_token",
|
||||
"images_kwargs",
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<IMG_CONTEXT>", # nosec: B107
|
||||
video_token="<IMG_CONTEXT>", # nosec: B107
|
||||
tokens_per_tile=256,
|
||||
image_placeholder="image",
|
||||
video_placeholder="video",
|
||||
image_start_token="<img>",
|
||||
image_end_token="</img>",
|
||||
**kwargs,
|
||||
):
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
self.video_token_id = (
|
||||
tokenizer.video_token_id
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
self.image_placeholder = image_placeholder
|
||||
self.video_placeholder = video_placeholder
|
||||
self.tokens_per_tile = tokens_per_tile
|
||||
self.image_start_token = image_start_token
|
||||
self.image_end_token = image_end_token
|
||||
if "auto_map" in kwargs:
|
||||
self.auto_map = kwargs["auto_map"]
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def replace_media_placeholder(
|
||||
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
|
||||
):
|
||||
num_of_images_in_this_sample = 0
|
||||
num_of_videos_in_this_sample = 0
|
||||
# Regular expression pattern to match formats like <image-1> or <video-2>
|
||||
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
|
||||
unified_frame_list = []
|
||||
|
||||
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
# )
|
||||
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
# )
|
||||
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
|
||||
# "use_thumbnail", self.image_processor.use_thumbnail
|
||||
# )
|
||||
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
)
|
||||
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
)
|
||||
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
|
||||
"use_thumbnail", self.image_processor.use_thumbnail
|
||||
)
|
||||
|
||||
tile_size = self.image_processor.size.get("height", 448)
|
||||
|
||||
# Function to replace tags in a single text
|
||||
def replace_in_text(text):
|
||||
# repl callback function for each match replacement operation
|
||||
def repl(match):
|
||||
nonlocal unified_frame_list
|
||||
nonlocal num_of_images_in_this_sample
|
||||
nonlocal num_of_videos_in_this_sample
|
||||
media_type = match.group(1) # 'image' or 'video'
|
||||
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
|
||||
# Select the corresponding path based on media type
|
||||
idx_mapper = {
|
||||
0: "first",
|
||||
1: "second",
|
||||
2: "third",
|
||||
3: "fourth",
|
||||
4: "fifth",
|
||||
5: "sixth",
|
||||
6: "seventh",
|
||||
7: "eighth",
|
||||
8: "ninth",
|
||||
9: "tenth",
|
||||
}
|
||||
if media_type == "image":
|
||||
image_inputs = self.image_processor(
|
||||
images=[image_list[idx_in_list]],
|
||||
videos=None,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
if isinstance(image_inputs["pixel_values"], list):
|
||||
_pv = image_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
image_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
unified_frame_list.append(image_inputs)
|
||||
num_of_images_in_this_sample += 1
|
||||
|
||||
elif media_type == "video":
|
||||
video_inputs = self.image_processor(
|
||||
images=None,
|
||||
videos=[video_list[idx_in_list]],
|
||||
**output_kwargs["videos_kwargs"],
|
||||
)
|
||||
if isinstance(video_inputs["pixel_values"], list):
|
||||
_pv = video_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
video_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||
image_sizes = video_inputs["image_sizes"]
|
||||
if timestamps_list is not None and -1 not in timestamps_list:
|
||||
frame_timestamps = timestamps_list[idx_in_list]
|
||||
else:
|
||||
frame_timestamps = None
|
||||
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
|
||||
|
||||
num_of_tiles_each_frame = [
|
||||
self.get_number_tiles_based_on_image_size(
|
||||
image_size,
|
||||
video_min_dynamic_tiles,
|
||||
video_max_dynamic_tiles,
|
||||
video_use_thumbnail,
|
||||
tile_size,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
|
||||
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
|
||||
)
|
||||
|
||||
if frame_timestamps is not None:
|
||||
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
|
||||
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
|
||||
)
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
else:
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
|
||||
if sampled_fps is not None:
|
||||
special_placeholder = (
|
||||
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
|
||||
+ "".join(special_placeholder)
|
||||
)
|
||||
else:
|
||||
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
|
||||
special_placeholder
|
||||
)
|
||||
unified_frame_list.append(video_inputs)
|
||||
num_of_videos_in_this_sample += 1
|
||||
else:
|
||||
raise ValueError(f"Unknown media type: {media_type}")
|
||||
return special_placeholder
|
||||
|
||||
return pattern.sub(repl, text)
|
||||
|
||||
text = replace_in_text(text)
|
||||
if len(unified_frame_list) > 0:
|
||||
|
||||
def _to_tensor(v):
|
||||
if isinstance(v, torch.Tensor):
|
||||
return v
|
||||
if isinstance(v, list):
|
||||
if v and isinstance(v[0], list):
|
||||
v = [t for sub in v for t in sub]
|
||||
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
|
||||
return torch.as_tensor(v)
|
||||
|
||||
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
|
||||
else:
|
||||
pixel_values = None
|
||||
image_sizes = None
|
||||
return (
|
||||
text,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
audio=None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Eagle25VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text_list = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
elif isinstance(text, list) and isinstance(text[0], str):
|
||||
text_list = text
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if videos is None:
|
||||
videos = []
|
||||
|
||||
pixel_values_list = []
|
||||
image_sizes_list = []
|
||||
new_sample_list = []
|
||||
image_start_idx = 0
|
||||
video_start_idx = 0
|
||||
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
|
||||
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
|
||||
for sample in text_list:
|
||||
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
|
||||
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
|
||||
(
|
||||
sample,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
) = self.replace_media_placeholder(
|
||||
sample,
|
||||
images[image_start_idx:],
|
||||
videos[video_start_idx:],
|
||||
timestamps_list,
|
||||
fps_list,
|
||||
**output_kwargs,
|
||||
)
|
||||
new_sample_list.append(sample)
|
||||
if pixel_values is not None:
|
||||
pixel_values_list.append(pixel_values)
|
||||
image_sizes_list.append(image_sizes)
|
||||
image_start_idx += num_of_images_in_this_sample
|
||||
video_start_idx += num_of_videos_in_this_sample
|
||||
|
||||
if len(pixel_values_list) > 0:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(pixel_values_list),
|
||||
"image_sizes": torch.cat(image_sizes_list),
|
||||
}
|
||||
else:
|
||||
image_inputs = {}
|
||||
video_inputs = {}
|
||||
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
|
||||
|
||||
def get_number_tiles_based_on_image_size(
|
||||
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of tiles based on the image size.
|
||||
"""
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if use_thumbnail and tiles_num > 1:
|
||||
tiles_num += 1
|
||||
return tiles_num
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# override to save video-config in a separate config file
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
outputs = super().save_pretrained(save_directory, **kwargs)
|
||||
return outputs
|
||||
|
||||
# override to load video-config from a separate config file
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
|
||||
if isinstance(processor, tuple):
|
||||
processor = processor[0]
|
||||
return processor
|
||||
|
||||
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
def process_vision_info(
|
||||
self,
|
||||
conversations: list[dict] | list[list[dict]],
|
||||
return_video_kwargs: bool = False,
|
||||
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
|
||||
vision_infos = self.extract_vision_info(conversations)
|
||||
## Read images or videos
|
||||
image_inputs = []
|
||||
video_inputs = []
|
||||
video_sample_fps_list = []
|
||||
video_timestamps_list = []
|
||||
for vision_info in vision_infos:
|
||||
if "image" in vision_info or "image_url" in vision_info:
|
||||
image_inputs.append(fetch_image(vision_info))
|
||||
else:
|
||||
raise ValueError("image, image_url or video should in content.")
|
||||
if len(image_inputs) == 0:
|
||||
image_inputs = None
|
||||
if len(video_inputs) == 0:
|
||||
video_inputs = None
|
||||
if return_video_kwargs:
|
||||
return (
|
||||
image_inputs,
|
||||
video_inputs,
|
||||
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
|
||||
)
|
||||
return image_inputs, video_inputs
|
||||
|
||||
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
||||
vision_infos = []
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if (
|
||||
"image" in ele
|
||||
or "image_url" in ele
|
||||
or "video" in ele
|
||||
or ele["type"] in ("image", "image_url", "video")
|
||||
):
|
||||
vision_infos.append(ele)
|
||||
return vision_infos
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLProcessor"]
|
||||
@@ -0,0 +1,380 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
|
||||
|
||||
from .action_head.flow_matching_action_head import (
|
||||
FlowmatchingActionHead,
|
||||
FlowmatchingActionHeadConfig,
|
||||
)
|
||||
from .utils import ensure_eagle_cache_ready
|
||||
|
||||
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
class EagleBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
|
||||
project_to_dim: int = 1536,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tune_llm: whether to tune the LLM model (default: True)
|
||||
tune_visual: whether to tune the visual model (default: False)
|
||||
"""
|
||||
super().__init__()
|
||||
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
|
||||
|
||||
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
|
||||
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
try:
|
||||
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
|
||||
|
||||
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
|
||||
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
if project_to_dim is not None:
|
||||
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
|
||||
else:
|
||||
self.eagle_linear = torch.nn.Identity()
|
||||
|
||||
# needed since we don't use these layers. Also saves compute
|
||||
while len(self.eagle_model.language_model.model.layers) > select_layer:
|
||||
self.eagle_model.language_model.model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual)
|
||||
|
||||
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.eagle_model.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.eagle_model.vision_model.requires_grad_(False)
|
||||
self.eagle_model.mlp1.requires_grad_(False)
|
||||
print(f"Tune backbone llm: {self.tune_llm}")
|
||||
print(f"Tune backbone visual: {self.tune_visual}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_llm and not tune_visual:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Backbone trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No backbone trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if self.eagle_model.language_model and not self.tune_llm:
|
||||
self.eagle_model.language_model.eval()
|
||||
if self.eagle_model.vision_model and not self.tune_visual:
|
||||
self.eagle_model.vision_model.eval()
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
eagle_prefix = "eagle_"
|
||||
eagle_input = {
|
||||
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
|
||||
}
|
||||
del eagle_input["image_sizes"]
|
||||
|
||||
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
|
||||
eagle_features = eagle_output.hidden_states[self.select_layer]
|
||||
|
||||
eagle_features = self.eagle_linear(eagle_features)
|
||||
return eagle_features, eagle_input["attention_mask"]
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
|
||||
|
||||
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
|
||||
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
|
||||
if self.training and self.tune_visual:
|
||||
dummy_term = torch.tensor(
|
||||
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
|
||||
)
|
||||
for param in self.eagle_model.vision_model.parameters():
|
||||
if param.requires_grad:
|
||||
dummy_term = dummy_term + 0.0 * param.sum()
|
||||
eagle_embeds = eagle_embeds + dummy_term
|
||||
|
||||
return BatchFeature(
|
||||
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
|
||||
) # [B, T2, hidden_size]
|
||||
|
||||
|
||||
BACKBONE_FEATURE_KEY = "backbone_features"
|
||||
ACTION_KEY = "action_pred"
|
||||
LOSS_KEY = "loss"
|
||||
ERROR_MSG = "Error: unexpected input/output"
|
||||
N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
class GR00TN15(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
config_class = GR00TN15Config
|
||||
"""
|
||||
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
|
||||
here n is variable and can be e.g. time, 1 or user specified
|
||||
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
|
||||
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN15Config,
|
||||
local_model_path: str,
|
||||
):
|
||||
assert isinstance(config.backbone_cfg, dict)
|
||||
assert isinstance(config.action_head_cfg, dict)
|
||||
|
||||
super().__init__(config)
|
||||
self.local_model_path = local_model_path
|
||||
|
||||
self.backbone = EagleBackbone(**config.backbone_cfg)
|
||||
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
|
||||
self.action_head = FlowmatchingActionHead(action_head_cfg)
|
||||
|
||||
self.action_horizon = config.action_horizon
|
||||
self.action_dim = config.action_dim
|
||||
self.compute_dtype = config.compute_dtype
|
||||
self.post_init()
|
||||
|
||||
def validate_inputs(self, inputs):
|
||||
# NOTE -- this should be handled internally by the model
|
||||
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
|
||||
|
||||
detected_error = False
|
||||
error_msg = ERROR_MSG
|
||||
if ACTION in inputs:
|
||||
action = inputs[ACTION]
|
||||
# In inference, action may be omitted or None; validate only when it's a tensor.
|
||||
if action is None:
|
||||
pass # allow None during inference
|
||||
elif isinstance(action, torch.Tensor):
|
||||
shape_ok = (
|
||||
len(action.shape) == 3
|
||||
and action.shape[1] == self.action_horizon
|
||||
and action.shape[2] == self.action_dim
|
||||
)
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{action.shape=}"
|
||||
detected_error = True
|
||||
else:
|
||||
# Unexpected non-tensor type provided for action
|
||||
error_msg += f"\nInvalid type for action: {type(action)}"
|
||||
detected_error = True
|
||||
|
||||
if "video" in inputs:
|
||||
video = inputs["video"]
|
||||
type_ok = isinstance(video, np.ndarray)
|
||||
dtype_ok = video.dtype == np.uint8
|
||||
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
|
||||
if not type_ok:
|
||||
error_msg += f"\n{type(video)=}"
|
||||
detected_error = True
|
||||
if not dtype_ok:
|
||||
error_msg += f"\n{video.dtype=}"
|
||||
detected_error = True
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{video.shape=}"
|
||||
detected_error = True
|
||||
|
||||
if detected_error:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
|
||||
fail_backbone = (
|
||||
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
|
||||
)
|
||||
|
||||
if fail_backbone:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
|
||||
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
|
||||
(
|
||||
LOSS_KEY in action_head_outputs and is_training
|
||||
) # there might not be an action prediction during training
|
||||
or (
|
||||
ACTION_KEY in action_head_outputs
|
||||
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
|
||||
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
|
||||
)
|
||||
)
|
||||
|
||||
if fail_action_head:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
|
||||
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
|
||||
error_msg += f"\n{self.action_horizon=}"
|
||||
error_msg += f"\n{self.action_dim=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
|
||||
return action_head_outputs
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
|
||||
return action_head_outputs
|
||||
|
||||
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
|
||||
self.validate_inputs(inputs)
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_maybe_dtype(x):
|
||||
# Cast floating tensors to a memory-efficient compute dtype when requested.
|
||||
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
|
||||
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
if getattr(self, "compute_dtype", None) == "bfloat16":
|
||||
return x.to(self.device, dtype=torch.bfloat16)
|
||||
# Fallback: preserve previous behavior if not using bf16 compute
|
||||
return x.to(self.device, dtype=self.action_head.dtype)
|
||||
# Non-floating tensors: move device only
|
||||
return x.to(self.device)
|
||||
|
||||
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
|
||||
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
|
||||
return backbone_inputs, action_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
|
||||
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
|
||||
print(f"Tune backbone vision tower: {tune_visual}")
|
||||
print(f"Tune backbone LLM: {tune_llm}")
|
||||
print(f"Tune action head projector: {tune_projector}")
|
||||
print(f"Tune action head DiT: {tune_diffusion_model}")
|
||||
|
||||
# get the current model path being downloaded
|
||||
try:
|
||||
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
|
||||
# saved in ~/.cache/huggingface/hub/
|
||||
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
|
||||
# HFValidationError, RepositoryNotFoundError
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
print(
|
||||
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
|
||||
)
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path, local_model_path=local_model_path, **kwargs
|
||||
)
|
||||
|
||||
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
|
||||
)
|
||||
return pretrained_model
|
||||
@@ -1,966 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
except ImportError:
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _copy_default(value: Any) -> Any:
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
"model_name": "nvidia/Cosmos-Reason2-2B",
|
||||
"backbone_model_type": "qwen",
|
||||
"model_revision": None,
|
||||
"tune_top_llm_layers": 0,
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 16,
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"load_bf16": False,
|
||||
"backbone_trainable_params_fp32": True,
|
||||
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||
"shortest_image_edge": None,
|
||||
"crop_fraction": None,
|
||||
"random_rotation_angle": None,
|
||||
"color_jitter_params": None,
|
||||
"use_albumentations_transforms": True,
|
||||
"extra_augmentation_config": None,
|
||||
"formalize_language": True,
|
||||
"apply_sincos_state_encoding": False,
|
||||
"use_percentiles": True,
|
||||
"use_relative_action": False,
|
||||
"max_state_dim": 132,
|
||||
"max_action_dim": 132,
|
||||
"action_horizon": 40,
|
||||
"hidden_size": 1024,
|
||||
"input_embedding_dim": 1536,
|
||||
"state_history_length": 1,
|
||||
"add_pos_embed": True,
|
||||
"attn_dropout": 0.2,
|
||||
"use_vlln": True,
|
||||
"max_seq_len": 1024,
|
||||
"use_alternate_vl_dit": True,
|
||||
"attend_text_every_n_blocks": 2,
|
||||
"diffusion_model_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 32,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 48,
|
||||
"norm_type": "ada_norm",
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
"output_dim": 1024,
|
||||
"interleave_self_attention": True,
|
||||
},
|
||||
"vl_self_attention_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 4,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 64,
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
},
|
||||
"num_inference_timesteps": 4,
|
||||
"noise_beta_alpha": 1.5,
|
||||
"noise_beta_beta": 1.0,
|
||||
"noise_s": 0.999,
|
||||
"num_timestep_buckets": 1000,
|
||||
"tune_projector": True,
|
||||
"tune_diffusion_model": True,
|
||||
"tune_vlln": True,
|
||||
"state_dropout_prob": 0.2,
|
||||
"exclude_state": False,
|
||||
"use_mean_std": False,
|
||||
"max_num_embodiments": 32,
|
||||
"rtc_ramp_rate": 6.0,
|
||||
}
|
||||
|
||||
|
||||
class GR00TN17Config(PretrainedConfig):
|
||||
"""Configuration for NVIDIA GR00T N1.7.
|
||||
|
||||
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
|
||||
flow-matching action head. This mirrors the public N1.7 checkpoint config
|
||||
while keeping it local to LeRobot and independent from the external
|
||||
Isaac-GR00T ``gr00t`` Python package.
|
||||
"""
|
||||
|
||||
model_type = "Gr00tN1d7"
|
||||
|
||||
_defaults = GR00T_N1_7_DEFAULTS
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in GR00T_N1_7_DEFAULTS.items():
|
||||
setattr(self, key, _copy_default(kwargs.pop(key, value)))
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
|
||||
cfg = self.to_dict()
|
||||
if not exclude_augment:
|
||||
return cfg
|
||||
exclude_keys = {
|
||||
"random_rotation_angle",
|
||||
"color_jitter_params",
|
||||
"use_albumentations_transforms",
|
||||
"formalize_language",
|
||||
"image_crop_size",
|
||||
"image_target_size",
|
||||
"shortest_image_edge",
|
||||
"crop_fraction",
|
||||
}
|
||||
return {k: v for k, v in cfg.items() if k not in exclude_keys}
|
||||
|
||||
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
|
||||
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
"""Linear layer with category-specific weights for multi-embodiment support."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
"""Two-layer MLP with category-specific weights."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
|
||||
|
||||
The frequency scalar is intentionally created on CPU and then broadcast with
|
||||
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
|
||||
embedding and avoids tiny dtype/device construction differences in parity
|
||||
tests.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
|
||||
|
||||
|
||||
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
"""Action encoder with category-specific projections and sinusoidal time encoding."""
|
||||
|
||||
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
|
||||
super().__init__()
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, horizon, _ = actions.shape
|
||||
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,).")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
|
||||
action_emb = self.W1(actions, cat_ids)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
|
||||
return self.W3(x, cat_ids)
|
||||
|
||||
|
||||
class Qwen3Backbone(nn.Module):
|
||||
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
|
||||
|
||||
The public checkpoint stores the action head in the GR00T checkpoint but
|
||||
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
|
||||
keeps the nested HF module layout compatible across transformer versions
|
||||
and exposes the hidden states consumed by the action head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "nvidia/Cosmos-Reason2-2B",
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
tune_top_llm_layers: int = 0,
|
||||
trainable_params_fp32: bool = False,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_pretrained_weights: bool = True,
|
||||
):
|
||||
if Qwen3VLForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
|
||||
"or use a transformers version that provides Qwen3-VL support."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if use_flash_attention:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
|
||||
extra_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
except ImportError:
|
||||
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
|
||||
extra_kwargs["attn_implementation"] = "sdpa"
|
||||
if load_bf16:
|
||||
extra_kwargs["torch_dtype"] = torch.bfloat16
|
||||
|
||||
if load_pretrained_weights:
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
**extra_kwargs,
|
||||
**transformers_loading_kwargs,
|
||||
).eval()
|
||||
else:
|
||||
self.model = self._from_backbone_config(
|
||||
model_name=model_name,
|
||||
model_kwargs=extra_kwargs,
|
||||
config_kwargs=transformers_loading_kwargs,
|
||||
).eval()
|
||||
|
||||
while len(self.language_model.layers) > select_layer:
|
||||
self.language_model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
|
||||
if load_bf16 and trainable_params_fp32:
|
||||
for parameter in self.parameters():
|
||||
if parameter.requires_grad:
|
||||
parameter.data = parameter.data.to(torch.float32)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
|
||||
) -> None:
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.visual.requires_grad_(False)
|
||||
if tune_top_llm_layers > 0:
|
||||
for layer in self.language_model.layers[-tune_top_llm_layers:]:
|
||||
for parameter in layer.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if self.language_model and not self.tune_llm:
|
||||
self.language_model.eval()
|
||||
if self.visual and not self.tune_visual:
|
||||
self.visual.eval()
|
||||
|
||||
@property
|
||||
def language_model(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).language_model
|
||||
|
||||
@property
|
||||
def visual(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).visual
|
||||
|
||||
def _from_backbone_config(
|
||||
self,
|
||||
*,
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
config_kwargs: dict[str, Any],
|
||||
) -> nn.Module:
|
||||
if _is_cosmos_reason2_backbone(model_name):
|
||||
backbone_config = _cosmos_reason2_qwen3_vl_config()
|
||||
else:
|
||||
if AutoConfig is None:
|
||||
raise ImportError(
|
||||
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
|
||||
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
if "mm_token_type_ids" in model_input:
|
||||
return
|
||||
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
|
||||
return
|
||||
|
||||
input_ids = model_input.get("input_ids")
|
||||
if input_ids is None:
|
||||
return
|
||||
|
||||
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[input_ids == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[input_ids == video_token_id] = 2
|
||||
|
||||
model_input["mm_token_type_ids"] = mm_token_type_ids
|
||||
|
||||
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
|
||||
|
||||
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
|
||||
drops text position ids before calling text-layer flash attention. GR00T
|
||||
N1.7 was aligned against the older Transformers path, where a fourth text
|
||||
position row is forwarded alongside the temporal/height/width rows. Adding
|
||||
the row here preserves the newer multimodal position computation while
|
||||
keeping flash attention on the legacy code path.
|
||||
"""
|
||||
|
||||
if "position_ids" in model_input:
|
||||
return
|
||||
|
||||
qwen3_model = getattr(self.model, "model", self.model)
|
||||
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
|
||||
if compute_3d_position_ids is None:
|
||||
return
|
||||
|
||||
position_ids = compute_3d_position_ids(
|
||||
input_ids=model_input.get("input_ids"),
|
||||
image_grid_thw=model_input.get("image_grid_thw"),
|
||||
video_grid_thw=model_input.get("video_grid_thw"),
|
||||
inputs_embeds=None,
|
||||
attention_mask=model_input.get("attention_mask"),
|
||||
past_key_values=None,
|
||||
mm_token_type_ids=model_input.get("mm_token_type_ids"),
|
||||
)
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
|
||||
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
|
||||
|
||||
model_input["position_ids"] = position_ids
|
||||
|
||||
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
|
||||
|
||||
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
|
||||
Newer releases expose the post-final-norm tensor there instead. Capturing
|
||||
the last decoder layer output directly keeps the N1.7 action head input
|
||||
stable across Transformers versions.
|
||||
"""
|
||||
|
||||
captured: dict[str, torch.Tensor] = {}
|
||||
|
||||
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
|
||||
if isinstance(output, torch.Tensor):
|
||||
captured["features"] = output
|
||||
elif isinstance(output, (tuple, list)) and output:
|
||||
captured["features"] = output[0]
|
||||
elif hasattr(output, "last_hidden_state"):
|
||||
captured["features"] = output.last_hidden_state
|
||||
|
||||
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
|
||||
try:
|
||||
outputs = self.model(**model_input, output_hidden_states=True)
|
||||
finally:
|
||||
hook.remove()
|
||||
|
||||
return captured.get("features", outputs.hidden_states[-1])
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
|
||||
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
|
||||
model_input = {key: vl_input[key] for key in keys_to_use}
|
||||
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
|
||||
self._ensure_mm_token_type_ids(model_input)
|
||||
self._ensure_legacy_qwen3_position_ids(model_input)
|
||||
features = self._last_decoder_layer_output(model_input)
|
||||
image_mask = model_input["input_ids"] == self.model.config.image_token_id
|
||||
attention_mask = model_input["attention_mask"] == 1
|
||||
return BatchFeature(
|
||||
data={
|
||||
"backbone_features": features,
|
||||
"backbone_attention_mask": attention_mask,
|
||||
"image_mask": image_mask,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class GR00TN17ActionHead(nn.Module):
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, config: GR00TN17Config):
|
||||
require_package("diffusers", extra="groot")
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
if config.use_alternate_vl_dit:
|
||||
self.model = AlternateVLDiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
|
||||
)
|
||||
else:
|
||||
self.model = DiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
)
|
||||
|
||||
self.action_dim = config.max_action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim * config.state_history_length,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=self.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
|
||||
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
|
||||
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
|
||||
else:
|
||||
self.vl_self_attention = nn.Identity()
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
self.state_dropout_prob = config.state_dropout_prob
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
|
||||
) -> None:
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
self.tune_vlln = tune_vlln
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
if not tune_vlln:
|
||||
self.vlln.requires_grad_(False)
|
||||
self.vl_self_attention.requires_grad_(False)
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
if not self.tune_vlln:
|
||||
self.vlln.eval()
|
||||
self.vl_self_attention.eval()
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
if self._beta_dist is None:
|
||||
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
|
||||
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
|
||||
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (1 - sample) * self.config.noise_s
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = self.vlln(backbone_output["backbone_features"])
|
||||
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
vl_embeds = backbone_output.backbone_features
|
||||
device = vl_embeds.device
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
if action_input.state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = action_input.state.view(action_input.state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, embodiment_id)
|
||||
|
||||
if self.training and self.state_dropout_prob > 0:
|
||||
do_dropout = (
|
||||
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
|
||||
)
|
||||
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
|
||||
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None]
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
)
|
||||
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
|
||||
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
|
||||
return BatchFeature(
|
||||
data={
|
||||
"loss": loss,
|
||||
"action_loss": action_loss,
|
||||
"action_mask": action_mask,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
state = action_input.state
|
||||
if state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = state.view(state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, action_input.embodiment_id)
|
||||
return BatchFeature(
|
||||
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action_with_features(
|
||||
self,
|
||||
backbone_features: torch.Tensor,
|
||||
state_features: torch.Tensor,
|
||||
embodiment_id: torch.Tensor,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
vl_embeds = backbone_features
|
||||
batch_size = vl_embeds.shape[0]
|
||||
device = vl_embeds.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.action_dim),
|
||||
dtype=vl_embeds.dtype,
|
||||
device=device,
|
||||
)
|
||||
dt = 1.0 / self.num_inference_timesteps
|
||||
vel_strength = torch.ones_like(actions)
|
||||
|
||||
if "action" in action_input:
|
||||
if options is None:
|
||||
raise ValueError("RTC options are required when action is provided to get_action.")
|
||||
action_horizon_before_padding = options["action_horizon"]
|
||||
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
|
||||
:,
|
||||
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
|
||||
:,
|
||||
]
|
||||
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
|
||||
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
|
||||
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
|
||||
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
|
||||
ramp = ramp / ramp[-1].clamp_min(1e-8)
|
||||
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
|
||||
None, :, None
|
||||
].to(device)
|
||||
|
||||
for t_step in range(self.num_inference_timesteps):
|
||||
t_cont = t_step / float(self.num_inference_timesteps)
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
|
||||
|
||||
return BatchFeature(
|
||||
data={
|
||||
"action_pred": actions,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(
|
||||
self,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
features = self._encode_features(backbone_output, action_input)
|
||||
return self.get_action_with_features(
|
||||
backbone_features=features.backbone_features,
|
||||
state_features=features.state_features,
|
||||
embodiment_id=action_input.embodiment_id,
|
||||
backbone_output=backbone_output,
|
||||
action_input=action_input,
|
||||
options=options,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
|
||||
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
|
||||
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
|
||||
|
||||
|
||||
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
||||
if Qwen3VLConfig is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLConfig is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
return Qwen3VLConfig(
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=True,
|
||||
text_config={
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 6144,
|
||||
"max_position_embeddings": 262144,
|
||||
"model_type": "qwen3_vl_text",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"rope_scaling": {
|
||||
"mrope_interleaved": True,
|
||||
"mrope_section": [24, 20, 20],
|
||||
"rope_type": "default",
|
||||
},
|
||||
"rope_theta": 5000000,
|
||||
"tie_word_embeddings": True,
|
||||
"use_cache": True,
|
||||
"vocab_size": 151936,
|
||||
},
|
||||
vision_config={
|
||||
"deepstack_visual_indexes": [5, 11, 17],
|
||||
"depth": 24,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1024,
|
||||
"in_channels": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"model_type": "qwen3_vl",
|
||||
"num_heads": 16,
|
||||
"num_position_embeddings": 2304,
|
||||
"out_hidden_size": 2048,
|
||||
"patch_size": 16,
|
||||
"spatial_merge_size": 2,
|
||||
"temporal_patch_size": 2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_backbone_cls(config: GR00TN17Config):
|
||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||
return Qwen3Backbone
|
||||
if config.backbone_model_type == "qwen":
|
||||
logger.warning(
|
||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||
"backbone because backbone_model_type='qwen'.",
|
||||
config.model_name,
|
||||
)
|
||||
return Qwen3Backbone
|
||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||
|
||||
|
||||
class GR00TN17(PreTrainedModel):
|
||||
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
|
||||
|
||||
config_class = GR00TN17Config
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN17Config,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_backbone_weights: bool = True,
|
||||
):
|
||||
super().__init__(config)
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
self.config = config
|
||||
backbone_cls = get_backbone_cls(config)
|
||||
self.backbone = backbone_cls(
|
||||
model_name=config.model_name,
|
||||
tune_llm=config.tune_llm,
|
||||
tune_visual=config.tune_visual,
|
||||
select_layer=config.select_layer,
|
||||
reproject_vision=config.reproject_vision,
|
||||
use_flash_attention=config.use_flash_attention,
|
||||
load_bf16=config.load_bf16,
|
||||
tune_top_llm_layers=config.tune_top_llm_layers,
|
||||
trainable_params_fp32=config.backbone_trainable_params_fp32,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_pretrained_weights=load_backbone_weights,
|
||||
)
|
||||
self.action_head = GR00TN17ActionHead(config)
|
||||
self.post_init()
|
||||
|
||||
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
|
||||
global tree
|
||||
if tree is None:
|
||||
require_package("dm-tree", extra="groot", import_name="tree")
|
||||
tree = importlib.import_module("tree")
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_dtype(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
return x.to(self.device, dtype=self.dtype)
|
||||
return x.to(self.device)
|
||||
|
||||
return (
|
||||
tree.map_structure(to_device_with_dtype, backbone_inputs),
|
||||
tree.map_structure(to_device_with_dtype, action_inputs),
|
||||
)
|
||||
|
||||
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head(backbone_outputs, action_inputs)
|
||||
|
||||
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head.get_action(backbone_outputs, action_inputs, options)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
tune_vlln = kwargs.pop("tune_vlln", True)
|
||||
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
|
||||
try:
|
||||
local_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
repo_type="model",
|
||||
revision=kwargs.get("revision"),
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
token=kwargs.get("token"),
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_backbone_weights=load_backbone_weights,
|
||||
**kwargs,
|
||||
)
|
||||
pretrained_model.backbone.set_trainable_parameters(
|
||||
tune_visual=tune_visual,
|
||||
tune_llm=tune_llm,
|
||||
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
|
||||
)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector,
|
||||
tune_diffusion_model=tune_diffusion_model,
|
||||
tune_vlln=tune_vlln,
|
||||
)
|
||||
return pretrained_model
|
||||
|
||||
|
||||
def _register_with_transformers() -> None:
|
||||
if AutoConfig is None or AutoModel is None:
|
||||
return
|
||||
try:
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
|
||||
try:
|
||||
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoModel.register(GR00TN17Config, GR00TN17)
|
||||
|
||||
|
||||
_register_with_transformers()
|
||||
@@ -17,13 +17,22 @@
|
||||
"""
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||
possible without porting their code. Dataset loading and training
|
||||
orchestration are handled by LeRobot's standard training stack.
|
||||
Minimal integration that delegates to Isaac-GR00T components where possible
|
||||
without porting their code. The intent is to:
|
||||
|
||||
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
|
||||
- Optionally align action horizon similar to gr00t_finetune.py
|
||||
- Expose predict_action via GR00T model.get_action
|
||||
- Provide a training forward that can call the GR00T model forward if batch
|
||||
structure matches.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -37,19 +46,8 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_groot import (
|
||||
GROOT_N1_5,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7,
|
||||
GrootConfig,
|
||||
infer_groot_model_version,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .configuration_groot import GrootConfig
|
||||
from .groot_n1 import GR00TN15
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
@@ -69,35 +67,37 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Initialize GR00T model using ported components
|
||||
self._groot_model = self._create_groot_model()
|
||||
self._action_queue_steps = self._resolve_action_queue_steps()
|
||||
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
|
||||
"""Create and initialize the GR00T model using Isaac-GR00T API.
|
||||
|
||||
This is only called when creating a NEW policy (not when loading from checkpoint).
|
||||
|
||||
Steps (delegating to Isaac-GR00T):
|
||||
1) Download and load pretrained model via GR00TN15.from_pretrained
|
||||
2) Align action horizon with data_config if provided
|
||||
"""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
model_kwargs = {
|
||||
"pretrained_model_name_or_path": self.config.base_model_path,
|
||||
"tune_llm": self.config.tune_llm,
|
||||
"tune_visual": self.config.tune_visual,
|
||||
"tune_projector": self.config.tune_projector,
|
||||
"tune_diffusion_model": self.config.tune_diffusion_model,
|
||||
}
|
||||
from .groot_n1_7 import GR00TN17
|
||||
|
||||
model = GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=True,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
model = GR00TN15.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.base_model_path,
|
||||
tune_llm=self.config.tune_llm,
|
||||
tune_visual=self.config.tune_visual,
|
||||
tune_projector=self.config.tune_projector,
|
||||
tune_diffusion_model=self.config.tune_diffusion_model,
|
||||
)
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self._action_queue_steps)
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -118,7 +118,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""Load Groot policy from pretrained model.
|
||||
|
||||
Handles two cases:
|
||||
1. Base GR00T N1.7 models - loads the raw model
|
||||
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
|
||||
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
|
||||
|
||||
Args:
|
||||
@@ -141,15 +141,9 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
requested_version = (
|
||||
normalize_groot_model_version(config.model_version)
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
pretrained_name_or_path,
|
||||
print(
|
||||
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
@@ -180,7 +174,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
@@ -196,15 +190,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(
|
||||
model_version=model_version,
|
||||
base_model_path=str(pretrained_name_or_path),
|
||||
)
|
||||
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
|
||||
|
||||
# Add minimal visual feature required for validation
|
||||
# validate_features() will automatically add state and action features
|
||||
@@ -225,16 +215,6 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
@@ -245,160 +225,21 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _resolve_action_queue_steps(self) -> int:
|
||||
n_action_steps = int(self.config.n_action_steps)
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
execution_horizon = infer_groot_n1_7_action_execution_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
horizons = [n_action_steps]
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
if execution_horizon is not None:
|
||||
horizons.append(execution_horizon)
|
||||
return min(horizons)
|
||||
|
||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||
|
||||
horizons = [actions.shape[1]]
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
|
||||
for horizon in (self.config.chunk_size, self.config.n_action_steps):
|
||||
horizon = int(horizon)
|
||||
if horizon > 0:
|
||||
horizons.append(horizon)
|
||||
|
||||
return max(1, min(horizons))
|
||||
|
||||
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
if include_action:
|
||||
allowed_base.update({"action", "action_mask"})
|
||||
|
||||
allowed_base.update(
|
||||
{
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
}
|
||||
)
|
||||
allowed_base.add("action_mask")
|
||||
|
||||
return {
|
||||
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
def _prepare_n1_7_rtc_inputs(
|
||||
self,
|
||||
inputs: dict[str, Tensor],
|
||||
*,
|
||||
inference_delay: object,
|
||||
prev_chunk_left_over: object,
|
||||
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
|
||||
if prev_chunk_left_over is None:
|
||||
return inputs, None
|
||||
if not isinstance(prev_chunk_left_over, torch.Tensor):
|
||||
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
|
||||
if prev_chunk_left_over.numel() == 0:
|
||||
return inputs, None
|
||||
|
||||
prev_actions = prev_chunk_left_over
|
||||
if prev_actions.ndim == 2:
|
||||
prev_actions = prev_actions.unsqueeze(0)
|
||||
elif prev_actions.ndim != 3:
|
||||
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||
|
||||
state = inputs.get("state")
|
||||
if state is None:
|
||||
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
|
||||
batch_size = state.shape[0]
|
||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||
elif prev_actions.shape[0] != batch_size:
|
||||
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||
|
||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||
# provided prefix row as a real action constraint, so strip that padding
|
||||
# before constructing the native overlap options.
|
||||
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
|
||||
if valid_prefix_rows.any():
|
||||
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
|
||||
prev_actions = prev_actions[:, :valid_prefix_steps, :]
|
||||
else:
|
||||
return inputs, None
|
||||
|
||||
model_action_horizon = int(
|
||||
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||
)
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||
|
||||
action_horizon = int(prev_actions.shape[1])
|
||||
if action_horizon <= 0:
|
||||
return inputs, None
|
||||
|
||||
if prev_actions.shape[2] > max_action_dim:
|
||||
prev_actions = prev_actions[:, :, :max_action_dim]
|
||||
elif prev_actions.shape[2] < max_action_dim:
|
||||
pad = torch.zeros(
|
||||
prev_actions.shape[0],
|
||||
prev_actions.shape[1],
|
||||
max_action_dim - prev_actions.shape[2],
|
||||
dtype=prev_actions.dtype,
|
||||
device=prev_actions.device,
|
||||
)
|
||||
prev_actions = torch.cat([prev_actions, pad], dim=2)
|
||||
|
||||
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
|
||||
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
|
||||
overlap_steps = max(0, min(action_horizon, execution_horizon))
|
||||
if overlap_steps == 0:
|
||||
return inputs, None
|
||||
|
||||
try:
|
||||
frozen_steps = int(inference_delay or 0)
|
||||
except (TypeError, ValueError):
|
||||
frozen_steps = 0
|
||||
frozen_steps = max(0, min(frozen_steps, overlap_steps))
|
||||
|
||||
options = {
|
||||
"action_horizon": action_horizon,
|
||||
"rtc_overlap_steps": overlap_steps,
|
||||
"rtc_frozen_steps": frozen_steps,
|
||||
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
|
||||
}
|
||||
|
||||
inputs = dict(inputs)
|
||||
inputs["action"] = prev_actions
|
||||
return inputs, options
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Training forward pass.
|
||||
|
||||
Delegates to Isaac-GR00T model.forward when inputs are compatible.
|
||||
"""
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = get_device_from_parameters(self)
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
@@ -407,54 +248,38 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
if loss is None:
|
||||
raise RuntimeError(
|
||||
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||
"'action' and 'action_mask'; check the preprocessor output."
|
||||
)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
|
||||
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
|
||||
action-overlap options before calling the underlying model.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
|
||||
groot_options = None
|
||||
if self.config.model_version == GROOT_N1_7:
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch
|
||||
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
# Get device from model parameters
|
||||
device = get_device_from_parameters(self)
|
||||
device = next(self.parameters()).device
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
if groot_options is not None:
|
||||
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
|
||||
else:
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
prediction_horizon = self._resolve_prediction_horizon(actions)
|
||||
actions = actions[:, :prediction_horizon]
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
@@ -467,28 +292,40 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,47 @@
|
||||
from pathlib import Path
|
||||
from shutil import copytree
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
|
||||
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
|
||||
|
||||
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
|
||||
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
vendor_dir = Path(vendor_dir)
|
||||
|
||||
try:
|
||||
# Populate/refresh cache with vendor files to ensure a complete processor directory
|
||||
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
|
||||
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
|
||||
|
||||
required_assets = [
|
||||
"vocab.json",
|
||||
"merges.txt",
|
||||
"added_tokens.json",
|
||||
"chat_template.json",
|
||||
"special_tokens_map.json",
|
||||
"config.json",
|
||||
"generation_config.json",
|
||||
"preprocessor_config.json",
|
||||
"processor_config.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
|
||||
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
|
||||
|
||||
for fname in required_assets:
|
||||
dst = cache_dir / fname
|
||||
if not dst.exists():
|
||||
print(f"[GROOT] Fetching {fname}")
|
||||
hf_hub_download(
|
||||
repo_id=assets_repo,
|
||||
filename=fname,
|
||||
repo_type="model",
|
||||
local_dir=str(cache_dir),
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/policy_vla_jepa_README.md
|
||||
@@ -1,23 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
@@ -1,337 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _diffusers_available:
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
else:
|
||||
|
||||
class ModelMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class ConfigMixin: # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
register_to_config = lambda f: f # noqa: E731
|
||||
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.layer1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
require_package("diffusers", extra="vla_jepa")
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
]
|
||||
)
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
action_is_pad: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
|
||||
if action_is_pad is None:
|
||||
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full(
|
||||
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||
)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
@@ -1,154 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 7
|
||||
n_action_steps: int = 7
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||
# different action or state dimensionality, the input/output projection layers must be
|
||||
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||
reinit_modules: list[str] | None = None
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 8
|
||||
|
||||
num_action_tokens_per_timestep: int = 8
|
||||
num_embodied_action_tokens_per_instruction: int = 32
|
||||
num_inference_timesteps: int = 4
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
num_target_vision_tokens: int = 32
|
||||
action_max_seq_len: int = 1024
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 8
|
||||
predictor_depth: int = 12
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
gripper_dim: int = 6
|
||||
gripper_threshold: float = 0.5
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.freeze_qwen and self.enable_world_model:
|
||||
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||
self.enable_world_model = False
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
self.action_dim = self.action_feature.shape[0]
|
||||
if self.robot_state_feature is not None:
|
||||
self.state_dim = self.robot_state_feature.shape[0]
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,629 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoVideoProcessor = None
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
# ============================================================================
|
||||
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
"""
|
||||
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||
|
||||
Components:
|
||||
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||
- DiT-B: flow-matching action head for future action prediction
|
||||
- V-JEPA: world model for video frame prediction
|
||||
|
||||
Input: List[dict] native format (same as original starVLA)
|
||||
- "image": List[PIL.Image] (multi-view images)
|
||||
- "video": np.ndarray [V, T, H, W, 3]
|
||||
- "lang": str (task instruction)
|
||||
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||
- "state": np.ndarray [1, state_dim] (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
require_package("transformers", extra="vla_jepa")
|
||||
self.config = config
|
||||
|
||||
# Vision-language backbone
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
|
||||
# Tokenizer expansion for special action tokens
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||
self.qwen.expand_tokenizer()
|
||||
)
|
||||
|
||||
# Action head (flow-matching DiT)
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = config.jepa_tubelet_size
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
if config.freeze_qwen:
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the last decoder hidden state before the final RMSNorm.
|
||||
|
||||
The model was trained with the output of the last transformer block BEFORE
|
||||
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||
the correct pre-RMSNorm state, matching the training-time representation.
|
||||
"""
|
||||
captured: list[torch.Tensor] = []
|
||||
|
||||
def _hook(module, input, output):
|
||||
h = output[0] if isinstance(output, tuple) else output
|
||||
captured.append(h)
|
||||
|
||||
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||
handle = last_layer.register_forward_hook(_hook)
|
||||
try:
|
||||
self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
return captured[0] # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||
|
||||
Args:
|
||||
examples: List of per-sample dicts with keys:
|
||||
"image" : List[PIL.Image] — multi-view images
|
||||
"video" : np.ndarray [V, T, H, W, 3]
|
||||
"lang" : str — task instruction
|
||||
"action" : np.ndarray [T, action_dim] (optional)
|
||||
"state" : np.ndarray [1, state_dim] (optional)
|
||||
|
||||
Returns:
|
||||
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||
"""
|
||||
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||
actions = [ex["action"] for ex in examples] if has_action else None
|
||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||
state = [ex["state"] for ex in examples] if has_state else None
|
||||
action_is_pad = (
|
||||
[ex["action_is_pad"] for ex in examples]
|
||||
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||
batch_videos = np.stack(batch_videos)
|
||||
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||
|
||||
# Adjust number of views for the world model:
|
||||
# - fewer views than expected: duplicate the first view to fill up
|
||||
# - more views than expected: keep only the first num_views_world_model views
|
||||
num_views_world_model = self.config.jepa_tubelet_size
|
||||
if batch_videos.shape[1] < num_views_world_model:
|
||||
num_missing_views = num_views_world_model - batch_videos.shape[1]
|
||||
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
|
||||
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
|
||||
elif batch_videos.shape[1] > num_views_world_model:
|
||||
batch_videos = batch_videos[:, :num_views_world_model]
|
||||
|
||||
# ---- Step 1: QwenVL encode (same as original) ----
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
action_horizon = self.config.chunk_size
|
||||
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
pad_tensor = torch.stack(
|
||||
[
|
||||
p.to(actions_target.device)
|
||||
if isinstance(p, Tensor)
|
||||
else torch.tensor(p, device=actions_target.device)
|
||||
for p in action_is_pad
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
batch_images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
state: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Native action prediction following original VLA_JEPA.predict_action.
|
||||
|
||||
Args:
|
||||
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||
instructions: Task instructions, one per sample.
|
||||
state: Optional [B, state_dim] numpy array.
|
||||
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
LeRobot adapter for VLA-JEPA.
|
||||
|
||||
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||
back to LeRobot format.
|
||||
"""
|
||||
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
if dataset_meta := kwargs.get("dataset_meta"):
|
||||
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||
ds_features = dataset_meta.features
|
||||
if OBS_STATE in ds_features:
|
||||
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||
if ACTION in ds_features:
|
||||
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
# ---- Format Conversion: LeRobot → Native ----
|
||||
|
||||
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||
"""
|
||||
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||
|
||||
LeRobot format:
|
||||
batch = {
|
||||
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||
"task": str | List[str], (optional instruction)
|
||||
}
|
||||
|
||||
Native format (List[dict]):
|
||||
{
|
||||
"image": List[PIL.Image], # multi-view images per sample
|
||||
"video": np.ndarray [V, T, H, W, 3],
|
||||
"lang": str, # task instruction
|
||||
"action": np.ndarray [T, action_dim], # optional
|
||||
"state": np.ndarray [1, state_dim], # optional
|
||||
}
|
||||
"""
|
||||
# Determine batch size from the first image feature
|
||||
image_keys = list(self.config.image_features.keys())
|
||||
if not image_keys:
|
||||
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||
first_key = image_keys[0]
|
||||
first_tensor = batch[first_key]
|
||||
batch_size = first_tensor.shape[0]
|
||||
|
||||
# ---- Collect images per sample ----
|
||||
# images_per_sample[b][v] = PIL.Image for view v
|
||||
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
# ---- Collect videos per sample ----
|
||||
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||
# Check whether any image feature has a time dimension
|
||||
video_source = None
|
||||
for k in image_keys:
|
||||
if k in batch:
|
||||
video_source = batch[k] # Use first available for shape inspection
|
||||
break
|
||||
|
||||
if video_source is None:
|
||||
raise ValueError("No image data found in batch for video construction.")
|
||||
|
||||
videos_per_sample = []
|
||||
for b in range(batch_size):
|
||||
sample_views = []
|
||||
for k in image_keys:
|
||||
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||
if t.ndim == 3:
|
||||
t = t.unsqueeze(0) # [1, C, H, W]
|
||||
# Convert to [T, H, W, 3] numpy
|
||||
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
# ---- Collect instructions ----
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
instructions = ["Execute the robot action."] * batch_size
|
||||
elif isinstance(tasks, str):
|
||||
instructions = [tasks] * batch_size
|
||||
else:
|
||||
instructions = list(tasks)
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
action_is_pad_list = None
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
action_is_pad_tensor = batch.get("action_is_pad")
|
||||
if action_is_pad_tensor is not None:
|
||||
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Assemble native examples ----
|
||||
examples = []
|
||||
for b in range(batch_size):
|
||||
example = {
|
||||
"image": images_per_sample[b],
|
||||
"video": videos_per_sample[b],
|
||||
"lang": instructions[b],
|
||||
}
|
||||
if actions_list is not None:
|
||||
example["action"] = actions_list[b]
|
||||
if action_is_pad_list is not None:
|
||||
example["action_is_pad"] = action_is_pad_list[b]
|
||||
if state_list is not None:
|
||||
example["state"] = state_list[b]
|
||||
examples.append(example)
|
||||
|
||||
return examples
|
||||
|
||||
# ---- LeRobot Policy Interface ----
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
native_output = self.model.forward(examples)
|
||||
|
||||
ref = next(iter(native_output.values()))
|
||||
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||
logs["loss"] = total_loss.detach().item()
|
||||
return total_loss, logs
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.model.parameters()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
examples = self._prepare_model_inputs(batch)
|
||||
batch_images = [ex["image"] for ex in examples]
|
||||
instructions = [ex["lang"] for ex in examples]
|
||||
|
||||
state_np = None
|
||||
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||
state_np = np.stack([ex["state"] for ex in examples])
|
||||
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
reinit_prefixes = model.config.reinit_modules
|
||||
if not reinit_prefixes:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
reinitialized: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||
)
|
||||
reinitialized.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if reinitialized:
|
||||
logging.warning(
|
||||
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
return model
|
||||
@@ -1,155 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||
class ClipActionsProcessorStep(ProcessorStep):
|
||||
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
transition = dict(transition)
|
||||
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes a gripper dimension after unnormalization.
|
||||
|
||||
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
|
||||
Only applied when action has more dimensions than gripper_dim.
|
||||
"""
|
||||
|
||||
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
|
||||
self.gripper_dim = gripper_dim
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] > self.gripper_dim:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(
|
||||
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
)
|
||||
)
|
||||
if config.binarize_gripper_action:
|
||||
output_steps.append(
|
||||
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
|
||||
)
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1,117 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
else:
|
||||
AutoProcessor = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
@@ -1,418 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(
|
||||
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
kd = rotate_queries_or_keys(
|
||||
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||
)
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(
|
||||
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_k.append(
|
||||
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||
)
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("relative_actions_processor")
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
|
||||
|
||||
@@ -20,14 +20,12 @@ from .factory import (
|
||||
make_reward_pre_post_processors as make_reward_pre_post_processors,
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig as RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
"TOPRewardConfig",
|
||||
# Base class
|
||||
|
||||
@@ -25,7 +25,6 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .topreward.configuration_topreward import TOPRewardConfig
|
||||
|
||||
@@ -39,7 +38,7 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm", "robometer", "topreward".
|
||||
"sarm", "topreward".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
@@ -55,10 +54,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "robometer":
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
return RobometerRewardModel
|
||||
elif name == "topreward":
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
@@ -79,7 +74,7 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm", "robometer", "topreward".
|
||||
"reward_classifier", "sarm", "topreward".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -92,8 +87,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
elif reward_type == "robometer":
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
else:
|
||||
@@ -175,13 +168,6 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(reward_cfg, RobometerConfig):
|
||||
from lerobot.rewards.robometer.processor_robometer import make_robometer_pre_post_processors
|
||||
|
||||
return make_robometer_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, TOPRewardConfig):
|
||||
from lerobot.rewards.topreward.processor_topreward import make_topreward_pre_post_processors
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_robometer import RobometerConfig
|
||||
from .modeling_robometer import RobometerRewardModel
|
||||
from .processor_robometer import make_robometer_pre_post_processors
|
||||
|
||||
__all__ = ["RobometerConfig", "RobometerRewardModel", "make_robometer_pre_post_processors"]
|
||||
@@ -1,320 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compute per-frame Robometer progress and success curves for a LeRobot dataset.
|
||||
|
||||
For each episode, builds per-frame sub-samples using the frame-steps
|
||||
strategy from the Robometer eval server: for each original frame ``t``,
|
||||
linspace-subsample ``[0, t]`` into ``K`` frames (default 4, matching
|
||||
``NUM_SUBSAMPLED_FRAMES`` in the eval server), run one forward through
|
||||
the Robometer processor + model, and keep the last-frame progress value.
|
||||
All sub-samples are the same size ``K`` so they batch cleanly.
|
||||
|
||||
The parquet uses the same schema as SARM's
|
||||
:mod:`lerobot.rewards.sarm.compute_rabc_weights` so existing consumers —
|
||||
:class:`lerobot.rewards.sarm.rabc.RABCWeights` (which reads
|
||||
``progress_sparse``) and the progress-overlay script in
|
||||
``examples/dataset/create_progress_videos.py`` — work without modification.
|
||||
|
||||
Usage:
|
||||
# Dense per-frame progress for one episode
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--episodes 0
|
||||
|
||||
# All episodes with batching
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--batch-size 16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
DEFAULT_OUTPUT_FILENAME = "robometer_progress.parquet"
|
||||
|
||||
# Upstream Robometer eval server uses K=4 for frame-steps sub-samples.
|
||||
DEFAULT_NUM_SUBSAMPLED_FRAMES = 4
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
"""Read ``reward_model_path`` from parquet metadata if available."""
|
||||
if not parquet_path.exists():
|
||||
return None
|
||||
try:
|
||||
metadata = pq.read_metadata(parquet_path).schema.to_arrow_schema().metadata
|
||||
if metadata and b"reward_model_path" in metadata:
|
||||
return metadata[b"reward_model_path"].decode()
|
||||
except Exception: # nosec B110
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_task(sample: dict[str, Any], default: str) -> str:
|
||||
"""Best-effort task extraction from a dataset sample."""
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task:
|
||||
return task
|
||||
return default
|
||||
|
||||
|
||||
def _build_subsample_indices(num_frames: int, num_subsampled_frames: int) -> list[np.ndarray]:
|
||||
"""Frame-steps linspace expansion.
|
||||
|
||||
For each ``t in [0, num_frames - 1]`` returns ``num_subsampled_frames``
|
||||
indices from ``np.linspace(0, t, num_subsampled_frames)`` — the first
|
||||
and last frames are always included. Each entry is a fixed-size array
|
||||
so the model can batch them.
|
||||
"""
|
||||
return [np.linspace(0, t, num_subsampled_frames).round().astype(np.int64) for t in range(num_frames)]
|
||||
|
||||
|
||||
def compute_robometer_progress(
|
||||
dataset_repo_id: str,
|
||||
reward_model_path: str,
|
||||
output_path: str | None = None,
|
||||
device: str = "cuda",
|
||||
batch_size: int = 32,
|
||||
num_subsampled_frames: int = DEFAULT_NUM_SUBSAMPLED_FRAMES,
|
||||
episodes: list[int] | None = None,
|
||||
image_key: str | None = None,
|
||||
) -> Path:
|
||||
"""Run Robometer over a dataset and write per-frame progress + success."""
|
||||
logging.info(f"Loading Robometer: {reward_model_path}")
|
||||
config = RobometerConfig(pretrained_path=reward_model_path, device=device)
|
||||
if image_key is not None:
|
||||
config.image_key = image_key
|
||||
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
|
||||
model.to(device).eval()
|
||||
|
||||
encoder = RobometerEncoderProcessorStep(
|
||||
base_model_id=config.base_model_id,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=num_subsampled_frames,
|
||||
use_multi_image=config.use_multi_image,
|
||||
use_per_frame_progress_token=config.use_per_frame_progress_token,
|
||||
)
|
||||
|
||||
image_key = config.image_key
|
||||
|
||||
logging.info(f"Loading dataset: {dataset_repo_id}")
|
||||
dataset = LeRobotDataset(dataset_repo_id, download_videos=True)
|
||||
logging.info(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
episode_indices = list(range(dataset.num_episodes)) if episodes is None else episodes
|
||||
logging.info(f"Processing {len(episode_indices)} episode(s)")
|
||||
|
||||
all_index: list[int] = []
|
||||
all_episode: list[int] = []
|
||||
all_frame: list[int] = []
|
||||
all_progress: list[float] = []
|
||||
|
||||
for episode_idx in tqdm(episode_indices, desc="Episodes"):
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
ep_start = int(ep["dataset_from_index"])
|
||||
ep_end = int(ep["dataset_to_index"])
|
||||
num_frames = ep_end - ep_start
|
||||
if num_frames <= 0:
|
||||
continue
|
||||
|
||||
first_sample = dataset[ep_start]
|
||||
task = _resolve_task(first_sample, default=config.default_task or "perform the task")
|
||||
|
||||
ep_frames = torch.stack([dataset[ep_start + i][image_key] for i in range(num_frames)])
|
||||
|
||||
sub_indices = _build_subsample_indices(num_frames, num_subsampled_frames)
|
||||
|
||||
progress_per_frame = np.zeros(num_frames, dtype=np.float32)
|
||||
|
||||
for start in tqdm(range(0, num_frames, batch_size), desc=f" Ep {episode_idx}", leave=False):
|
||||
end = min(start + batch_size, num_frames)
|
||||
frames_batch = torch.stack([ep_frames[sub_indices[i]] for i in range(start, end)])
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {image_key: frames_batch},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": task},
|
||||
}
|
||||
encoded = encoder(transition)
|
||||
obs = encoded[TransitionKey.OBSERVATION]
|
||||
batch = {
|
||||
key: value.to(device) if isinstance(value, torch.Tensor) else value
|
||||
for key, value in obs.items()
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
rewards = model.compute_reward(batch)
|
||||
progress_per_frame[start:end] = rewards.cpu().numpy()
|
||||
|
||||
for local in range(num_frames):
|
||||
all_index.append(ep_start + local)
|
||||
all_episode.append(episode_idx)
|
||||
all_frame.append(local)
|
||||
all_progress.append(float(progress_per_frame[local]))
|
||||
|
||||
if device.startswith("cuda"):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"index": np.asarray(all_index, dtype=np.int64),
|
||||
"episode_index": np.asarray(all_episode, dtype=np.int64),
|
||||
"frame_index": np.asarray(all_frame, dtype=np.int64),
|
||||
"progress_sparse": np.asarray(all_progress, dtype=np.float32),
|
||||
}
|
||||
).replace_schema_metadata({b"reward_model_path": reward_model_path.encode()})
|
||||
|
||||
out = Path(dataset.root) / DEFAULT_OUTPUT_FILENAME if output_path is None else Path(output_path)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(table, out)
|
||||
logging.info(f"Saved {len(table)} frame values to {out}")
|
||||
|
||||
progress_arr = np.asarray(all_progress, dtype=np.float32)
|
||||
if progress_arr.size:
|
||||
logging.info(
|
||||
f"Progress: mean={float(progress_arr.mean()):.4f}, "
|
||||
f"std={float(progress_arr.std()):.4f}, "
|
||||
f"min={float(progress_arr.min()):.4f}, "
|
||||
f"max={float(progress_arr.max()):.4f}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compute per-frame Robometer progress curves for RA-BC weighting.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Dense per-frame progress for one episode
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--episodes 0
|
||||
|
||||
# All episodes, smaller batches for memory-constrained GPUs
|
||||
python -m lerobot.rewards.robometer.compute_rabc_weights \\
|
||||
--dataset-repo-id lerobot/libero_10_image \\
|
||||
--reward-model-path lerobot/Robometer-4B \\
|
||||
--batch-size 16
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id", type=str, required=True, help="HuggingFace dataset repo id or local path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-model-path", type=str, default=None, help="Robometer checkpoint repo id or local path."
|
||||
)
|
||||
parser.add_argument("--output-path", type=str, default=None, help="Output parquet path.")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device to use (default: cuda).")
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=32, help="Sub-samples per Qwen forward (default: 32)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-subsampled-frames",
|
||||
type=int,
|
||||
default=DEFAULT_NUM_SUBSAMPLED_FRAMES,
|
||||
help=f"Frames per sub-sample (default: {DEFAULT_NUM_SUBSAMPLED_FRAMES}, matches eval server).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes", type=int, nargs="+", default=None, help="Process only these episode indices."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image-key", type=str, default=None, help="Image observation key (default: from config)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub", action="store_true", help="Upload to the dataset repo on HuggingFace Hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
reward_model_path = args.reward_model_path
|
||||
if reward_model_path is None:
|
||||
temp_dataset = LeRobotDataset(args.dataset_repo_id, download_videos=False)
|
||||
parquet_path = Path(temp_dataset.root) / DEFAULT_OUTPUT_FILENAME
|
||||
reward_model_path = get_reward_model_path_from_parquet(parquet_path)
|
||||
if reward_model_path:
|
||||
logging.info(f"Using reward model from parquet metadata: {reward_model_path}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"--reward-model-path is required (no existing parquet with model metadata found)."
|
||||
)
|
||||
|
||||
output_path = compute_robometer_progress(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
reward_model_path=reward_model_path,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
batch_size=args.batch_size,
|
||||
num_subsampled_frames=args.num_subsampled_frames,
|
||||
episodes=args.episodes,
|
||||
image_key=args.image_key,
|
||||
)
|
||||
|
||||
print(f"\nRobometer progress saved to: {output_path}")
|
||||
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
hub_path = DEFAULT_OUTPUT_FILENAME
|
||||
|
||||
print(f"\nUploading to Hub: {args.dataset_repo_id}/{hub_path}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path),
|
||||
path_in_repo=hub_path,
|
||||
repo_id=args.dataset_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
print(
|
||||
"Successfully uploaded to: "
|
||||
f"https://huggingface.co/datasets/{args.dataset_repo_id}/blob/main/{hub_path}"
|
||||
)
|
||||
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: hf://datasets/{args.dataset_repo_id}/{hub_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
else:
|
||||
print("\nTo use in training, add to your config:")
|
||||
print(" use_rabc: true")
|
||||
print(f" rabc_progress_path: {output_path}")
|
||||
print(" rabc_head_mode: sparse")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,158 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
else:
|
||||
AutoConfig = None # type: ignore[assignment]
|
||||
AutoTokenizer = None # type: ignore[assignment]
|
||||
|
||||
|
||||
# Special tokens Robometer adds to the Qwen-VL tokenizer at construction time.
|
||||
# The order is part of the data contract: upstream resized ``embed_tokens``
|
||||
# after adding these tokens in this exact order, so changing the set or order
|
||||
# would silently misalign the saved embedding rows with their token ids.
|
||||
# ``<|reward_token|>`` and ``<|sim_token|>`` are leftover from earlier upstream
|
||||
# heads (never read at inference) but still occupy rows the checkpoint expects.
|
||||
ROBOMETER_SPECIAL_TOKENS = (
|
||||
"<|split_token|>",
|
||||
"<|reward_token|>",
|
||||
"<|pref_token|>",
|
||||
"<|sim_token|>",
|
||||
"<|prog_token|>",
|
||||
)
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("robometer")
|
||||
@dataclass
|
||||
class RobometerConfig(RewardModelConfig):
|
||||
"""Configuration for the Robometer reward model."""
|
||||
|
||||
pretrained_path: str | None = "lerobot/Robometer-4B"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
|
||||
max_frames: int | None = 8
|
||||
reward_output: str = "progress" # "progress" or "success"
|
||||
success_threshold: float = 0.5
|
||||
|
||||
license: str | None = "apache-2.0"
|
||||
tags: list[str] | None = field(
|
||||
default_factory=lambda: ["reward-model", "vision-language", "qwen3-vl", "zero-shot"]
|
||||
)
|
||||
|
||||
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
torch_dtype: str = "bfloat16"
|
||||
use_multi_image: bool = True
|
||||
use_per_frame_progress_token: bool = True
|
||||
average_temporal_patches: bool = True
|
||||
frame_pooling: str = "mean" # "mean" | "boundary" | "attention"
|
||||
frame_pooling_attn_temperature: float = 1.0
|
||||
progress_loss_type: str = "discrete" # "l1" | "l2" | "discrete"
|
||||
progress_discrete_bins: int = 10
|
||||
|
||||
# Serialised Qwen backbone config (post-resize). Always populated by
|
||||
# ``__post_init__`` from ``base_model_id`` + ``len(tokenizer) + 5``, so it
|
||||
# is non-empty after construction. Saved into ``config.json`` automatically
|
||||
# by the base ``_save_pretrained``.
|
||||
vlm_config: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.reward_output not in {"progress", "success"}:
|
||||
raise ValueError(f"reward_output must be 'progress' or 'success', got {self.reward_output!r}")
|
||||
if self.max_frames is not None and self.max_frames < 1:
|
||||
raise ValueError(f"max_frames must be >= 1, got {self.max_frames}")
|
||||
if self.frame_pooling not in {"mean", "boundary", "attention"}:
|
||||
raise ValueError(f"frame_pooling must be mean/boundary/attention; got {self.frame_pooling!r}")
|
||||
if self.frame_pooling_attn_temperature <= 0:
|
||||
raise ValueError("frame_pooling_attn_temperature must be > 0")
|
||||
if self.progress_loss_type not in {"l1", "l2", "discrete"}:
|
||||
raise ValueError(f"progress_loss_type must be l1/l2/discrete; got {self.progress_loss_type!r}")
|
||||
if self.use_per_frame_progress_token and not self.use_multi_image:
|
||||
raise ValueError("use_per_frame_progress_token=True requires use_multi_image=True")
|
||||
|
||||
if self.image_key not in self.input_features:
|
||||
self.input_features[self.image_key] = PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL)
|
||||
self.output_features.setdefault("progress", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
self.output_features.setdefault("success", PolicyFeature(shape=(1,), type=FeatureType.REWARD))
|
||||
|
||||
# Deterministically populate ``vlm_config`` so it is non-empty after
|
||||
# construction. For ``Qwen/Qwen3-VL-4B-Instruct`` this gives
|
||||
# ``len(tokenizer) + 5 = 151,669 + 5 = 151,674`` — the exact post-resize
|
||||
# vocab the published ``Robometer-4B`` checkpoint was saved with.
|
||||
if not self.vlm_config:
|
||||
require_package("transformers", extra="robometer")
|
||||
vlm = AutoConfig.from_pretrained(self.base_model_id).to_dict()
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.base_model_id)
|
||||
text_config = vlm.get("text_config")
|
||||
if not isinstance(text_config, dict):
|
||||
raise ValueError(
|
||||
f"Backbone config for {self.base_model_id!r} has no nested `text_config`; "
|
||||
"Robometer expects a Qwen-VL-style config."
|
||||
)
|
||||
text_config["vocab_size"] = len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)
|
||||
self.vlm_config = vlm
|
||||
|
||||
@property
|
||||
def use_discrete_progress(self) -> bool:
|
||||
"""Whether the progress head outputs distribution logits over bins."""
|
||||
return self.progress_loss_type.lower() == "discrete"
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self):
|
||||
"""Reconstruct the Qwen backbone config from :attr:`vlm_config`."""
|
||||
require_package("transformers", extra="robometer")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
model_type = config_dict.pop("model_type", None)
|
||||
if model_type is None:
|
||||
raise ValueError("vlm_config must include `model_type` to reconstruct the backbone config")
|
||||
return AutoConfig.for_model(model_type, **config_dict)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.image_key not in self.input_features:
|
||||
raise ValueError(f"Robometer requires image input feature {self.image_key!r}")
|
||||
@@ -1,481 +0,0 @@
|
||||
# Copyright 2026 Anthony Liang, Yigit Korkmaz, Stephen Tu, Erdem Bıyık, Jesse Zhang
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""ROBOMETER: Scaling General-Purpose Robotic Reward Models via Trajectory Comparisons.
|
||||
|
||||
Paper: https://arxiv.org/abs/2603.02115
|
||||
Project: https://robometer.github.io
|
||||
Original code: https://github.com/aliang8/robometer
|
||||
Model: https://huggingface.co/robometer/Robometer-4B
|
||||
|
||||
Robometer is a general-purpose, video-language-input reward model built on
|
||||
``Qwen/Qwen3-VL-4B-Instruct``. It is trained with a dual reward-prediction
|
||||
objective:
|
||||
|
||||
- A frame-level progress loss anchoring reward magnitude on expert data.
|
||||
- A trajectory-comparison preference loss imposing global ordering constraints
|
||||
across trajectories sharing the same instruction.
|
||||
|
||||
To support downstream RL it also predicts a frame-level binary success. The
|
||||
training prompt inserts three learnable tokens:
|
||||
|
||||
- ``<|prog_token|>`` after each frame to read per-frame progress and success.
|
||||
- ``<|pref_token|>`` at the end to read pairwise preference (training-only).
|
||||
- ``<|split_token|>`` between two trajectories in preference samples
|
||||
(training-only).
|
||||
|
||||
Progress is modeled as a categorical distribution over ``progress_discrete_bins``
|
||||
uniformly-spaced centers in ``[0, 1]`` (C51-style), and the continuous estimate
|
||||
is recovered as the softmax-weighted mean of those centers — see
|
||||
:func:`convert_bins_to_continuous`.
|
||||
|
||||
This LeRobot port is **inference-only**: the preference head is preserved in
|
||||
the state dict for byte-equivalence with the published ``Robometer-4B``
|
||||
checkpoint but is not queried by :meth:`RobometerRewardModel.compute_reward`,
|
||||
which returns the last-frame progress (clamped to ``[0, 1]``) or sigmoid'd
|
||||
success probability depending on :attr:`RobometerConfig.reward_output`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.robometer.configuration_robometer import RobometerConfig
|
||||
from lerobot.utils.constants import OBS_PREFIX
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModelForImageTextToText
|
||||
else:
|
||||
AutoModelForImageTextToText = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Namespace for Robometer's pre-encoded Qwen-VL observation tensors.
|
||||
ROBOMETER_FEATURE_PREFIX = f"{OBS_PREFIX}robometer."
|
||||
ROBOMETER_QWEN_INPUT_KEYS = (
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"pixel_values_videos",
|
||||
"image_grid_thw",
|
||||
"video_grid_thw",
|
||||
"second_per_grid_ts",
|
||||
"mm_token_type_ids",
|
||||
)
|
||||
ROBOMETER_METADATA_KEYS = (
|
||||
"prog_token_id",
|
||||
"vision_start_token_id",
|
||||
"vision_end_token_id",
|
||||
"video_merge_size",
|
||||
)
|
||||
ROBOMETER_INPUT_KEYS = ROBOMETER_QWEN_INPUT_KEYS + ROBOMETER_METADATA_KEYS
|
||||
|
||||
|
||||
def convert_bins_to_continuous(bin_logits: Tensor) -> Tensor:
|
||||
"""Collapse per-bin logits into a single value in ``[0, 1]``.
|
||||
|
||||
The discrete progress head outputs ``num_bins`` logits per frame. Bins are
|
||||
evenly spaced centers in ``[0, 1]``; the continuous prediction is the
|
||||
softmax-weighted mean of those centers.
|
||||
"""
|
||||
bin_probs = torch.softmax(bin_logits, dim=-1)
|
||||
num_bins = bin_logits.shape[-1]
|
||||
bin_centers = torch.linspace(0.0, 1.0, num_bins, device=bin_logits.device, dtype=bin_logits.dtype)
|
||||
return (bin_probs * bin_centers).sum(dim=-1)
|
||||
|
||||
|
||||
def _squeeze_last_safe(x: Tensor) -> Tensor:
|
||||
"""Drop a trailing singleton dim only when present."""
|
||||
return x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x
|
||||
|
||||
|
||||
def _torch_dtype(name: str) -> torch.dtype:
|
||||
dtype = getattr(torch, name, None)
|
||||
if isinstance(dtype, torch.dtype):
|
||||
return dtype
|
||||
raise ValueError(f"Unknown torch dtype: {name!r}")
|
||||
|
||||
|
||||
class RobometerPredictionHead(nn.Sequential):
|
||||
"""Small MLP head used for Robometer's progress / success / preference outputs."""
|
||||
|
||||
def __init__(self, hidden_dim: int, output_size: int, *, dropout: float, with_sigmoid: bool) -> None:
|
||||
layers: list[nn.Module] = [
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim // 2, output_size),
|
||||
]
|
||||
if with_sigmoid:
|
||||
layers.append(nn.Sigmoid())
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
def decode_progress_outputs(
|
||||
progress_logits: Tensor | None,
|
||||
success_logits: Tensor | None,
|
||||
*,
|
||||
is_discrete_mode: bool,
|
||||
) -> dict[str, list[list[float]]]:
|
||||
"""Decode RBM head outputs into per-frame floats.
|
||||
|
||||
Args:
|
||||
progress_logits: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
|
||||
success_logits: ``(B, T)`` raw logits, ``sigmoid``-ed to probabilities.
|
||||
is_discrete_mode: if True the progress logits get a softmax over bins
|
||||
and are projected onto bin centers via :func:`convert_bins_to_continuous`.
|
||||
|
||||
Returns:
|
||||
Dict with ``progress_pred`` and ``success_probs``, each a list of
|
||||
length ``B`` of per-frame float lists.
|
||||
"""
|
||||
progress_pred: list[list[float]] = []
|
||||
success_probs: list[list[float]] = []
|
||||
|
||||
if progress_logits is not None:
|
||||
for sample_logits in progress_logits:
|
||||
if is_discrete_mode:
|
||||
continuous = convert_bins_to_continuous(sample_logits.detach().float().cpu())
|
||||
progress_pred.append(continuous.flatten().tolist())
|
||||
else:
|
||||
progress_pred.append(sample_logits.detach().float().cpu().flatten().tolist())
|
||||
|
||||
if success_logits is not None:
|
||||
for sample_logits in success_logits:
|
||||
success_probs.append(torch.sigmoid(sample_logits.detach().float().cpu()).flatten().tolist())
|
||||
|
||||
return {"progress_pred": progress_pred, "success_probs": success_probs}
|
||||
|
||||
|
||||
class RobometerRewardModel(PreTrainedRewardModel):
|
||||
"""Robometer (RBM) reward model — inference-only LeRobot port.
|
||||
|
||||
Wraps a Qwen-VL backbone (default: ``Qwen/Qwen3-VL-4B-Instruct``) with three
|
||||
prediction heads from the paper (progress, success, preference). At
|
||||
inference time only the progress and success heads are queried; the
|
||||
preference head is kept on the module so the published ``Robometer-4B``
|
||||
safetensors load unchanged.
|
||||
"""
|
||||
|
||||
name = "robometer"
|
||||
config_class = RobometerConfig
|
||||
|
||||
def __init__(self, config: RobometerConfig, *, dropout: float = 0.1) -> None:
|
||||
require_package("transformers", extra="robometer")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Two backbone-build paths (EO-1 style, branched on ``pretrained_path``):
|
||||
#
|
||||
# - Fresh training (``pretrained_path is None``): download the base
|
||||
# Qwen weights and resize the embed table to match
|
||||
# ``vlm_config.text_config.vocab_size`` — populated deterministically
|
||||
# in ``RobometerConfig.__post_init__`` as
|
||||
# ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``
|
||||
#
|
||||
# - Loading a saved checkpoint (``pretrained_path`` is set): rebuild
|
||||
# the empty architecture from ``vlm_config`` via
|
||||
# ``AutoModelForImageTextToText.from_config`` so the subsequent
|
||||
# ``model.safetensors`` load is a direct fill of the right shape —
|
||||
# no redundant Qwen weight download.
|
||||
torch_dtype = _torch_dtype(config.torch_dtype)
|
||||
if config.pretrained_path is None:
|
||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||
config.base_model_id,
|
||||
dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
target_vocab = config.vlm_config["text_config"]["vocab_size"]
|
||||
self.model.resize_token_embeddings(target_vocab)
|
||||
else:
|
||||
self.model = AutoModelForImageTextToText.from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# All Qwen-VL backbones Robometer supports expose `text_config.hidden_size`.
|
||||
# Falls back to the top-level `hidden_size` so future non-multimodal
|
||||
# variants would still resolve.
|
||||
backbone_config = self.model.config
|
||||
text_config = getattr(backbone_config, "text_config", None)
|
||||
hidden_size = getattr(text_config, "hidden_size", None) if text_config is not None else None
|
||||
if hidden_size is None:
|
||||
hidden_size = getattr(backbone_config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
raise AttributeError(
|
||||
f"Could not infer hidden_size from backbone config of {config.base_model_id}"
|
||||
)
|
||||
hidden_dim = int(hidden_size)
|
||||
|
||||
# Robometer's three prediction heads + frame-pool attention.
|
||||
progress_output = config.progress_discrete_bins if config.use_discrete_progress else 1
|
||||
self.progress_head = RobometerPredictionHead(
|
||||
hidden_dim,
|
||||
progress_output,
|
||||
dropout=dropout,
|
||||
with_sigmoid=not config.use_discrete_progress,
|
||||
)
|
||||
self.preference_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
|
||||
self.success_head = RobometerPredictionHead(hidden_dim, 1, dropout=dropout, with_sigmoid=False)
|
||||
self.frame_pool_attn = nn.Linear(hidden_dim, 1, bias=False)
|
||||
|
||||
# Match the dtype of the loaded base model so weight loading is a no-op cast.
|
||||
model_dtype = next(self.model.parameters()).dtype
|
||||
self.progress_head.to(dtype=model_dtype)
|
||||
self.preference_head.to(dtype=model_dtype)
|
||||
self.success_head.to(dtype=model_dtype)
|
||||
self.frame_pool_attn.to(dtype=model_dtype)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
inputs = {
|
||||
key: batch[f"{ROBOMETER_FEATURE_PREFIX}{key}"]
|
||||
for key in ROBOMETER_INPUT_KEYS
|
||||
if f"{ROBOMETER_FEATURE_PREFIX}{key}" in batch
|
||||
}
|
||||
if "input_ids" not in inputs:
|
||||
raise KeyError(
|
||||
f"Robometer batch missing pre-encoded inputs (expected "
|
||||
f"`{ROBOMETER_FEATURE_PREFIX}input_ids`). Make sure the "
|
||||
"RobometerEncoderProcessorStep ran before `compute_reward`."
|
||||
)
|
||||
|
||||
device = next(self.model.parameters()).device
|
||||
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in inputs.items()}
|
||||
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
progress_logits, success_logits = self._compute_rbm_logits(inputs)
|
||||
|
||||
decoded = decode_progress_outputs(
|
||||
progress_logits,
|
||||
success_logits,
|
||||
is_discrete_mode=self.config.use_discrete_progress,
|
||||
)
|
||||
values = (
|
||||
decoded["success_probs"] if self.config.reward_output == "success" else decoded["progress_pred"]
|
||||
)
|
||||
|
||||
rewards = torch.stack([torch.as_tensor(seq, dtype=torch.float32)[-1] for seq in values])
|
||||
if self.config.reward_output == "success":
|
||||
rewards = (rewards > self.config.success_threshold).float()
|
||||
else:
|
||||
# Match upstream Robometer's ``extract_rewards_from_output``: per-frame
|
||||
# progress predictions are clamped to ``[0, 1]`` before being returned.
|
||||
rewards = rewards.clamp(0.0, 1.0)
|
||||
return rewards.to(self.config.device or "cpu")
|
||||
|
||||
def _compute_rbm_logits(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Run the Qwen3-VL backbone and apply Robometer's heads.
|
||||
|
||||
``inputs`` is the encoded batch produced by
|
||||
:class:`RobometerEncoderProcessorStep`. It carries Qwen tensors as well
|
||||
as Robometer-specific metadata (``prog_token_id``,
|
||||
``vision_start_token_id``, ``vision_end_token_id``, ``video_merge_size``)
|
||||
— the metadata is popped here so the rest can be forwarded straight to
|
||||
the Qwen model.
|
||||
|
||||
Returns ``(progress_logits, success_logits)``. Shapes:
|
||||
|
||||
- ``progress_logits``: ``(B, T)`` (continuous) or ``(B, T, num_bins)`` (discrete).
|
||||
- ``success_logits``: ``(B, T)`` raw logits (sigmoid happens at decode time).
|
||||
"""
|
||||
prog_token_id = inputs.pop("prog_token_id", None)
|
||||
vision_start_token_id = inputs.pop("vision_start_token_id", None)
|
||||
vision_end_token_id = inputs.pop("vision_end_token_id", None)
|
||||
video_merge_size = inputs.pop("video_merge_size", 14)
|
||||
|
||||
# Qwen3-VL doesn't reliably populate `last_hidden_state`; ask for the
|
||||
# full hidden-state tuple and take the last layer. This matches the
|
||||
# `is_qwen3` path in upstream Robometer's `RBM.forward_qwen` (main).
|
||||
outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
hidden_state = (
|
||||
outputs.hidden_states[-1]
|
||||
if getattr(outputs, "hidden_states", None)
|
||||
else outputs.last_hidden_state
|
||||
)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
if self.config.use_per_frame_progress_token:
|
||||
if prog_token_id is None:
|
||||
raise KeyError("`prog_token_id` missing in batch (run RobometerEncoderProcessorStep first)")
|
||||
return self._process_token_extraction(hidden_state, input_ids, prog_token_id=prog_token_id)
|
||||
if self.config.use_multi_image:
|
||||
if vision_start_token_id is None or vision_end_token_id is None:
|
||||
raise KeyError(
|
||||
"`vision_start_token_id` / `vision_end_token_id` missing in batch "
|
||||
"(run RobometerEncoderProcessorStep first)"
|
||||
)
|
||||
return self._process_multi_image_frames(
|
||||
hidden_state,
|
||||
input_ids,
|
||||
start_id=vision_start_token_id,
|
||||
end_id=vision_end_token_id,
|
||||
)
|
||||
video_grid_thw = inputs.get("video_grid_thw")
|
||||
if video_grid_thw is None:
|
||||
raise ValueError("video_grid_thw is required for video-mode Robometer inference")
|
||||
if vision_start_token_id is None:
|
||||
raise KeyError("`vision_start_token_id` missing in batch")
|
||||
return self._process_video_frames(
|
||||
hidden_state,
|
||||
input_ids,
|
||||
video_grid_thw,
|
||||
start_id=vision_start_token_id,
|
||||
merge_size=video_merge_size,
|
||||
)
|
||||
|
||||
def _apply_heads_to_hidden_states(self, frame_embeddings: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Apply progress + success heads to a tensor of frame embeddings."""
|
||||
progress_out = self.progress_head(frame_embeddings)
|
||||
progress = progress_out if self.config.use_discrete_progress else _squeeze_last_safe(progress_out)
|
||||
success = _squeeze_last_safe(self.success_head(frame_embeddings))
|
||||
return progress, success
|
||||
|
||||
def _process_token_extraction(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
*,
|
||||
prog_token_id: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success from ``<|prog_token|>`` positions."""
|
||||
token_mask = input_ids == prog_token_id
|
||||
batch_indices, positions = token_mask.nonzero(as_tuple=True)
|
||||
if positions.numel() == 0:
|
||||
raise ValueError("`<|prog_token|>` not found in any sequence")
|
||||
|
||||
per_sample_hidden = [
|
||||
hidden_state[i, positions[batch_indices == i]] for i in range(input_ids.shape[0])
|
||||
]
|
||||
progress_list, success_list = [], []
|
||||
for embeddings in per_sample_hidden:
|
||||
if embeddings.shape[0] == 0:
|
||||
raise ValueError("`<|prog_token|>` missing in a sequence")
|
||||
progress, success = self._apply_heads_to_hidden_states(embeddings)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
|
||||
def _process_multi_image_frames(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
*,
|
||||
start_id: int,
|
||||
end_id: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success in multi-image mode (Qwen-VL)."""
|
||||
progress_list, success_list = [], []
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
seq_ids = input_ids[batch_idx]
|
||||
seq_hidden = hidden_state[batch_idx]
|
||||
frame_embeddings = self._extract_hidden_states_from_token_pairs(
|
||||
seq_hidden, seq_ids, start_id, end_id
|
||||
)
|
||||
progress, success = self._apply_heads_to_hidden_states(frame_embeddings)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
|
||||
def _extract_hidden_states_from_token_pairs(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
start_id: int,
|
||||
end_id: int,
|
||||
) -> Tensor:
|
||||
start_positions = (input_ids == start_id).nonzero(as_tuple=True)[0]
|
||||
end_positions = (input_ids == end_id).nonzero(as_tuple=True)[0]
|
||||
if start_positions.numel() == 0:
|
||||
raise ValueError("`<|vision_start|>` not found in sequence")
|
||||
if start_positions.numel() != end_positions.numel():
|
||||
raise ValueError(
|
||||
f"Mismatched vision token counts: {start_positions.numel()} start vs "
|
||||
f"{end_positions.numel()} end"
|
||||
)
|
||||
|
||||
frames: list[Tensor] = []
|
||||
for start, end in zip(start_positions.tolist(), end_positions.tolist(), strict=True):
|
||||
if start >= end:
|
||||
raise ValueError(f"Invalid vision token pair: start={start} end={end}")
|
||||
patch_tokens = hidden_state[start + 1 : end]
|
||||
if patch_tokens.shape[0] == 0:
|
||||
frames.append((hidden_state[start] + hidden_state[end]) / 2.0)
|
||||
continue
|
||||
|
||||
pooling = self.config.frame_pooling
|
||||
if pooling == "mean":
|
||||
frames.append(patch_tokens.mean(dim=0))
|
||||
elif pooling == "boundary":
|
||||
frames.append(patch_tokens[-1])
|
||||
else: # attention
|
||||
scores = (
|
||||
self.frame_pool_attn(patch_tokens).squeeze(-1)
|
||||
/ self.config.frame_pooling_attn_temperature
|
||||
)
|
||||
weights = torch.softmax(scores, dim=0).unsqueeze(-1)
|
||||
frames.append((weights * patch_tokens).sum(dim=0))
|
||||
|
||||
return torch.stack(frames)
|
||||
|
||||
def _process_video_frames(
|
||||
self,
|
||||
hidden_state: Tensor,
|
||||
input_ids: Tensor,
|
||||
video_grid_thw: Tensor,
|
||||
*,
|
||||
start_id: int,
|
||||
merge_size: int,
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Per-frame progress/success in video mode (Qwen-VL)."""
|
||||
progress_list, success_list = [], []
|
||||
for batch_idx in range(input_ids.shape[0]):
|
||||
seq_ids = input_ids[batch_idx]
|
||||
seq_hidden = hidden_state[batch_idx]
|
||||
start_positions = (seq_ids == start_id).nonzero(as_tuple=True)[0]
|
||||
if start_positions.numel() == 0:
|
||||
raise ValueError("`<|vision_start|>` not found in sequence")
|
||||
t_dim, h_dim, w_dim = (int(x) for x in video_grid_thw[batch_idx].tolist())
|
||||
tokens_per_frame = (h_dim * w_dim) // (merge_size**2)
|
||||
|
||||
cursor = start_positions[0].item()
|
||||
frame_embeddings: list[Tensor] = []
|
||||
for _ in range(t_dim):
|
||||
if self.config.average_temporal_patches:
|
||||
patch = seq_hidden[cursor : cursor + tokens_per_frame]
|
||||
frame_embeddings.append(patch.mean(dim=0))
|
||||
else:
|
||||
frame_embeddings.append(seq_hidden[cursor + tokens_per_frame])
|
||||
cursor += tokens_per_frame
|
||||
|
||||
stacked = torch.stack(frame_embeddings)
|
||||
progress, success = self._apply_heads_to_hidden_states(stacked)
|
||||
progress_list.append(progress)
|
||||
success_list.append(success)
|
||||
|
||||
return torch.stack(progress_list), torch.stack(success_list)
|
||||
@@ -1,338 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Robometer pre/post processing pipelines."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
policy_action_to_transition,
|
||||
)
|
||||
from lerobot.rewards.robometer.configuration_robometer import (
|
||||
ROBOMETER_SPECIAL_TOKENS,
|
||||
RobometerConfig,
|
||||
)
|
||||
from lerobot.rewards.robometer.modeling_robometer import ROBOMETER_FEATURE_PREFIX
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor
|
||||
else:
|
||||
AutoProcessor = None
|
||||
|
||||
PROGRESS_PROMPT = (
|
||||
"The task for the robot is '{task}'. Given the trajectory video, predict "
|
||||
"the task progress at each frame, how far along the robot is towards "
|
||||
"completing the task, a float between 0 and 1, where 0 is the starting "
|
||||
"state and 1 is when the task is completed. If the robot is not "
|
||||
"performing the same task, predict 0 progress."
|
||||
)
|
||||
|
||||
|
||||
def _frames_to_pil(frames: np.ndarray) -> list[Image.Image]:
|
||||
"""Convert ``(T, H, W, C)`` uint8 frames to a list of PIL images."""
|
||||
if frames.ndim != 4:
|
||||
raise ValueError(f"Expected (T,H,W,C) frames; got shape {frames.shape}")
|
||||
if frames.dtype != np.uint8:
|
||||
frames = np.clip(frames, 0, 255).astype(np.uint8)
|
||||
return [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
|
||||
|
||||
|
||||
def _video_to_numpy(video: Tensor, *, max_frames: int | None) -> np.ndarray:
|
||||
"""Convert one trajectory tensor to a ``(T, H, W, C) uint8`` numpy array."""
|
||||
if max_frames is not None:
|
||||
video = video[-max_frames:]
|
||||
if video.shape[1] in (1, 3):
|
||||
video = video.permute(0, 2, 3, 1)
|
||||
elif video.shape[-1] not in (1, 3):
|
||||
raise ValueError(f"Expected channel dim of size 1 or 3, got shape {tuple(video.shape)}")
|
||||
|
||||
array = video.detach().cpu().numpy()
|
||||
if np.issubdtype(array.dtype, np.floating) and array.size > 0 and array.max() <= 1.0:
|
||||
array = array * 255.0
|
||||
return np.clip(array, 0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def _expand_tasks(task: Any, *, batch_size: int, default: str | None) -> list[str]:
|
||||
if task is None:
|
||||
task = default
|
||||
if task is None:
|
||||
raise KeyError("Robometer expected a task description in complementary data")
|
||||
if isinstance(task, str):
|
||||
return [task] * batch_size
|
||||
if isinstance(task, tuple):
|
||||
task = list(task)
|
||||
if not (isinstance(task, list) and all(isinstance(item, str) for item in task)):
|
||||
raise TypeError(f"Robometer task must be a string or list of strings, got {type(task)}")
|
||||
if len(task) == 1 and batch_size > 1:
|
||||
return task * batch_size
|
||||
if len(task) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} tasks, got {len(task)}")
|
||||
return task
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="robometer_encoder")
|
||||
class RobometerEncoderProcessorStep(ProcessorStep):
|
||||
"""Encode raw frames + task into Qwen-VL tensors for the Robometer model.
|
||||
|
||||
Loads a :class:`~transformers.AutoProcessor` matching ``base_model_id`` and
|
||||
registers Robometer's special tokens on the tokenizer. The matching
|
||||
embedding resize happens model-side in
|
||||
:meth:`RobometerRewardModel.__init__`.
|
||||
|
||||
At call time the step reads:
|
||||
|
||||
- ``observation[image_key]``: ``(B, T, C, H, W)`` or ``(B, C, H, W)`` frames.
|
||||
- ``complementary_data[task_key]``: a string or list of strings.
|
||||
|
||||
and writes ``observation[f"{ROBOMETER_FEATURE_PREFIX}<name>"]`` for:
|
||||
|
||||
- the Qwen-VL processor outputs: ``input_ids``, ``attention_mask``,
|
||||
``pixel_values``, ``image_grid_thw``, ``video_grid_thw``, ...
|
||||
- Robometer-specific token ids consumed by the model heads:
|
||||
``prog_token_id``, ``vision_start_token_id``, ``vision_end_token_id``,
|
||||
``video_merge_size``.
|
||||
"""
|
||||
|
||||
base_model_id: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
image_key: str = OBS_IMAGES + ".top"
|
||||
task_key: str = "task"
|
||||
default_task: str | None = None
|
||||
max_frames: int | None = 8
|
||||
use_multi_image: bool = True
|
||||
use_per_frame_progress_token: bool = True
|
||||
max_length: int = 1024
|
||||
|
||||
_processor: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
require_package("transformers", extra="robometer")
|
||||
require_package("qwen-vl-utils", extra="robometer", import_name="qwen_vl_utils")
|
||||
|
||||
self._processor = AutoProcessor.from_pretrained(
|
||||
self.base_model_id,
|
||||
trust_remote_code=True,
|
||||
do_sample_frames=False,
|
||||
padding_side="right",
|
||||
)
|
||||
|
||||
# Register Robometer's special tokens on the tokenizer. The matching
|
||||
# embedding resize happens model-side in `RobometerRewardModel.__init__`.
|
||||
tokenizer = self._processor.tokenizer
|
||||
# Qwen tokenizers may not define a pad token, but batched prompts/videos
|
||||
# require padding, so reuse EOS as the padding token.
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
for token in ROBOMETER_SPECIAL_TOKENS:
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [token]})
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("RobometerEncoderProcessorStep requires an observation dict")
|
||||
|
||||
if self.image_key not in observation:
|
||||
raise KeyError(f"Robometer expected image key {self.image_key!r} in observation")
|
||||
|
||||
frames = observation[self.image_key]
|
||||
tensor = frames.detach().cpu() if isinstance(frames, Tensor) else torch.as_tensor(frames)
|
||||
if tensor.ndim == 4:
|
||||
tensor = tensor.unsqueeze(1)
|
||||
elif tensor.ndim != 5:
|
||||
raise ValueError(
|
||||
f"Expected Robometer frames with shape (B,C,H,W) or (B,T,C,H,W); got {tuple(tensor.shape)}"
|
||||
)
|
||||
|
||||
batch_size = tensor.shape[0]
|
||||
tasks = _expand_tasks(
|
||||
complementary.get(self.task_key, self.default_task),
|
||||
batch_size=batch_size,
|
||||
default=self.default_task,
|
||||
)
|
||||
|
||||
samples = [
|
||||
(_video_to_numpy(tensor[i], max_frames=self.max_frames), tasks[i]) for i in range(batch_size)
|
||||
]
|
||||
encoded = self.encode_samples(samples)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for key, value in encoded.items():
|
||||
new_observation[f"{ROBOMETER_FEATURE_PREFIX}{key}"] = value
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def encode_samples(self, samples: list[tuple[np.ndarray, str]]) -> dict[str, Tensor]:
|
||||
"""Run the Qwen-VL processor on a list of ``(frames, task)`` samples."""
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
conversations = [self._build_conversation(frames, task) for frames, task in samples]
|
||||
|
||||
texts = [
|
||||
self._processor.apply_chat_template(
|
||||
msg,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
add_vision_id=True,
|
||||
enable_thinking=False,
|
||||
fps=1,
|
||||
)
|
||||
for msg in conversations
|
||||
]
|
||||
|
||||
process_kwargs: dict[str, Any] = {
|
||||
"return_video_kwargs": True,
|
||||
"return_video_metadata": True,
|
||||
}
|
||||
image_processor = getattr(self._processor, "image_processor", None)
|
||||
if image_processor is not None and hasattr(image_processor, "patch_size"):
|
||||
process_kwargs["image_patch_size"] = image_processor.patch_size
|
||||
|
||||
image_inputs, video_inputs, video_kwargs = process_vision_info(conversations, **process_kwargs)
|
||||
|
||||
videos: list[Any] | None = None
|
||||
video_metadatas: list[Any] | None = None
|
||||
if video_inputs:
|
||||
if isinstance(video_inputs[0], tuple) and len(video_inputs[0]) == 2:
|
||||
videos_seq, metadatas_seq = zip(*video_inputs, strict=False)
|
||||
videos = list(videos_seq)
|
||||
video_metadatas = list(metadatas_seq)
|
||||
else:
|
||||
videos = list(video_inputs)
|
||||
|
||||
processor_kwargs: dict[str, Any] = {
|
||||
"text": texts,
|
||||
"images": image_inputs,
|
||||
"padding": True,
|
||||
"truncation": False,
|
||||
"max_length": self.max_length,
|
||||
"return_tensors": "pt",
|
||||
"do_resize": False,
|
||||
}
|
||||
if videos is not None:
|
||||
processor_kwargs["videos"] = videos
|
||||
if video_metadatas is not None:
|
||||
processor_kwargs["video_metadata"] = video_metadatas
|
||||
if video_kwargs:
|
||||
processor_kwargs.update(video_kwargs)
|
||||
|
||||
encoded = self._processor(**processor_kwargs)
|
||||
|
||||
# Write Robometer-specific token ids and the video patch merge size into
|
||||
# the encoded batch so `RobometerRewardModel` doesn't need its own
|
||||
# tokenizer at inference (EO1-style separation: the processor owns the
|
||||
# tokenizer, the model owns the backbone and heads).
|
||||
tokenizer = self._processor.tokenizer
|
||||
encoded["prog_token_id"] = tokenizer.convert_tokens_to_ids("<|prog_token|>")
|
||||
encoded["vision_start_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_start|>")
|
||||
encoded["vision_end_token_id"] = tokenizer.convert_tokens_to_ids("<|vision_end|>")
|
||||
video_processor = getattr(self._processor, "video_processor", None)
|
||||
encoded["video_merge_size"] = int(getattr(video_processor, "merge_size", 14))
|
||||
return encoded
|
||||
|
||||
def _build_conversation(self, frames: np.ndarray, task: str) -> list[dict[str, Any]]:
|
||||
pil_frames = _frames_to_pil(frames)
|
||||
prompt = PROGRESS_PROMPT.format(task=task)
|
||||
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
|
||||
|
||||
if self.use_multi_image:
|
||||
for image in pil_frames:
|
||||
content.append({"type": "image", "image": image})
|
||||
if self.use_per_frame_progress_token:
|
||||
content.append({"type": "text", "text": "<|prog_token|>"})
|
||||
else:
|
||||
content.append({"type": "video", "video": pil_frames, "sample_fps": 1.0})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"base_model_id": self.base_model_id,
|
||||
"image_key": self.image_key,
|
||||
"task_key": self.task_key,
|
||||
"default_task": self.default_task,
|
||||
"max_frames": self.max_frames,
|
||||
"use_multi_image": self.use_multi_image,
|
||||
"use_per_frame_progress_token": self.use_per_frame_progress_token,
|
||||
"max_length": self.max_length,
|
||||
}
|
||||
|
||||
|
||||
def make_robometer_pre_post_processors(
|
||||
config: RobometerConfig,
|
||||
dataset_stats: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Pipeline that pre-encodes frames + task into Qwen-VL tensors.
|
||||
|
||||
The preprocessor adds a batch dimension if needed, runs Robometer's
|
||||
encoder, and moves everything to the configured device. The
|
||||
postprocessor is the identity since Robometer outputs a single reward
|
||||
tensor.
|
||||
"""
|
||||
del dataset_stats # Robometer has its own normalisation inside the Qwen-VL processor.
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RobometerEncoderProcessorStep(
|
||||
base_model_id=config.base_model_id,
|
||||
image_key=config.image_key,
|
||||
task_key=config.task_key,
|
||||
default_task=config.default_task,
|
||||
max_frames=config.max_frames,
|
||||
use_multi_image=config.use_multi_image,
|
||||
use_per_frame_progress_token=config.use_per_frame_progress_token,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -292,8 +292,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_relative_actions=true with pretrained processors can skip relative transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -301,31 +312,24 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
preprocessor_overrides = {
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": {**policy.config.input_features, **policy.config.output_features},
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
}
|
||||
postprocessor_overrides = {
|
||||
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
|
||||
"rename_map": cfg.rename_map
|
||||
}
|
||||
postprocessor_kwargs["postprocessor_overrides"] = {
|
||||
"unnormalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": policy.config.output_features,
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
}
|
||||
if getattr(active_cfg, "use_relative_actions", False):
|
||||
preprocessor_overrides["relative_actions_processor"] = {
|
||||
"enabled": True,
|
||||
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
|
||||
"action_names": getattr(active_cfg, "action_feature_names", None),
|
||||
}
|
||||
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
|
||||
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
processor_kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||
@@ -337,6 +341,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
|
||||
{% elif model_name == "sarm" %}
|
||||
A Success-Aware Reward Model (SARM) predicts a dense reward signal from observations, typically used downstream for reinforcement learning or human-in-the-loop fine-tuning when task success is not directly observable.
|
||||
{% elif model_name == "robometer" %}
|
||||
ROBOMETER is a general-purpose video-language robotic reward model built on a fine-tuned Qwen3-VL-4B backbone with progress, preference, and success heads. Given a trajectory video and a task description, it predicts dense, frame-level task progress in [0, 1] and frame-level success probabilities for downstream robot learning, including offline RL, online RL, data filtering and retrieval, and automated failure detection.
|
||||
{% elif model_name == "topreward" %}
|
||||
TOPReward is a **zero-shot** reward model that extracts token log-probabilities from an off-the-shelf vision-language model (default Qwen3-VL) as a reward signal. Given a video trajectory and a task instruction, it returns the VLM's log-likelihood of the instruction being true, with no fine-tuning required.
|
||||
{% else %}
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# Local-only parity artifacts (regenerated via dump_original_n1_7.py); never committed.
|
||||
*.npz
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes."""
|
||||
"""Test script for LeRobot's Groot policy forward and inference passes."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
@@ -41,20 +41,13 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
# Define constants for dummy data (GR00T N1.7 native conventions).
|
||||
# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images
|
||||
# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch
|
||||
# matches the model's native action space.
|
||||
# Define constants for dummy data
|
||||
DUMMY_STATE_DIM = 44
|
||||
DUMMY_ACTION_DIM = 44
|
||||
DUMMY_ACTION_HORIZON = 40
|
||||
DUMMY_ACTION_HORIZON = 16
|
||||
IMAGE_SIZE = 256
|
||||
DEVICE = auto_select_torch_device()
|
||||
# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads
|
||||
# via GrootPolicy.from_pretrained with root-level sharded safetensors.
|
||||
MODEL_PATH = "nvidia/GR00T-N1.7-3B"
|
||||
# Valid N1.7 embodiment tag carried by the checkpoint metadata.
|
||||
EMBODIMENT_TAG = "gr1_unified"
|
||||
MODEL_PATH = "aractingi/bimanual-handover-groot-10k"
|
||||
|
||||
|
||||
def cleanup_memory():
|
||||
@@ -95,13 +88,13 @@ def instantiate_lerobot_groot(
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor."""
|
||||
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
|
||||
if from_pretrained:
|
||||
policy = GrootPolicy.from_pretrained(
|
||||
pretrained_name_or_path=model_path,
|
||||
strict=False,
|
||||
)
|
||||
policy.config.embodiment_tag = EMBODIMENT_TAG
|
||||
policy.config.embodiment_tag = "gr1"
|
||||
else:
|
||||
config = GrootConfig(
|
||||
base_model_path=model_path,
|
||||
@@ -109,7 +102,7 @@ def instantiate_lerobot_groot(
|
||||
chunk_size=DUMMY_ACTION_HORIZON,
|
||||
image_size=[IMAGE_SIZE, IMAGE_SIZE],
|
||||
device=DEVICE,
|
||||
embodiment_tag=EMBODIMENT_TAG,
|
||||
embodiment_tag="gr1",
|
||||
)
|
||||
policy = GrootPolicy(config)
|
||||
|
||||
@@ -155,8 +148,8 @@ def create_dummy_data(device=DEVICE):
|
||||
|
||||
@require_cuda
|
||||
def test_lerobot_groot_inference():
|
||||
"""Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy."""
|
||||
print("Test: LeRobot GR00T N1.7 Inference Pass")
|
||||
"""Test the inference pass (select_action) of LeRobot's Groot policy."""
|
||||
print("Test: LeRobot Groot Inference Pass")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
@@ -188,9 +181,9 @@ def test_lerobot_groot_inference():
|
||||
|
||||
@require_cuda
|
||||
def test_lerobot_groot_forward_pass():
|
||||
"""Test the forward pass of LeRobot's GR00T N1.7 policy."""
|
||||
"""Test the forward pass of LeRobot's Groot policy."""
|
||||
print("\n" + "=" * 50)
|
||||
print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)")
|
||||
print("Test: LeRobot Groot Forward Pass (Training Mode)")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,194 +14,431 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||
|
||||
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
|
||||
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
|
||||
normalized flow-matching prediction before any action decoding) as NVIDIA's original
|
||||
``gr00t`` package, given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed. The comparison is parametrized over every embodiment tag present
|
||||
in the checkpoint.
|
||||
|
||||
To keep the comparison fair, the original outputs + the exact collated inputs are
|
||||
produced once per embodiment in the original ``gr00t`` env via the companion script
|
||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
|
||||
to per-tag ``.npz`` files.
|
||||
This test discovers those artifacts, replays the identical inputs through the LeRobot
|
||||
model, and compares.
|
||||
|
||||
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
|
||||
present, or when no artifact has been generated. By default it looks for artifacts in
|
||||
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
||||
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
||||
for the full run procedure.
|
||||
"""
|
||||
"""Test script to verify Groot policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
|
||||
pytest.importorskip("gr00t")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="Requires a local GR00T N1.7 checkpoint + pre-generated artifacts; not for CI.",
|
||||
reason="This test requires local Groot installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
||||
|
||||
SEED = 42
|
||||
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
||||
RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
|
||||
from gr00t.data.dataset import ModalityConfig # noqa: E402
|
||||
from gr00t.data.embodiment_tags import EmbodimentTag # noqa: E402
|
||||
from gr00t.data.transform.base import ComposedModalityTransform # noqa: E402
|
||||
from gr00t.model.policy import Gr00tPolicy # noqa: E402
|
||||
|
||||
# Artifact filenames are original_n1_7_<embodiment_tag>.npz
|
||||
_ARTIFACT_PREFIX = "original_n1_7_"
|
||||
_ARTIFACT_SUFFIX = ".npz"
|
||||
# GR1 humanoid dimensions (from pretrained model metadata)
|
||||
# The actual GR1 robot has 44 dimensions for both state and action
|
||||
# GR00TTransform will pad state to 64 and truncate action to 32
|
||||
DUMMY_STATE_DIM = 44
|
||||
DUMMY_ACTION_DIM = 44
|
||||
DUMMY_ACTION_HORIZON = 16
|
||||
IMAGE_SIZE = 256
|
||||
DEVICE = "cpu"
|
||||
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
|
||||
|
||||
GR1_BODY_PARTS = {
|
||||
"left_arm": 7,
|
||||
"left_hand": 6,
|
||||
"left_leg": 6,
|
||||
"neck": 3,
|
||||
"right_arm": 7,
|
||||
"right_hand": 6,
|
||||
"right_leg": 6,
|
||||
"waist": 3,
|
||||
}
|
||||
|
||||
|
||||
def _artifact_dir() -> Path:
|
||||
"""Directory holding the per-embodiment .npz artifacts.
|
||||
|
||||
Self-contained by default: a sibling ``artifacts/`` directory next to this test.
|
||||
Override with ``GROOT_N1_7_PARITY_DIR`` (e.g. to point at a scratch location).
|
||||
The directory is read-only here -- it is populated by ``utils/dump_original_n1_7.py``
|
||||
run in the original gr00t environment; the test never creates it.
|
||||
"""
|
||||
env = os.environ.get("GROOT_N1_7_PARITY_DIR")
|
||||
if env:
|
||||
return Path(env)
|
||||
return Path(__file__).resolve().parent / "artifacts"
|
||||
|
||||
|
||||
def _discover_artifacts() -> list[tuple[str, Path]]:
|
||||
"""Return [(embodiment_tag, npz_path), ...] for every dumped artifact."""
|
||||
d = _artifact_dir()
|
||||
if not d.is_dir():
|
||||
return []
|
||||
out = []
|
||||
for p in sorted(d.glob(f"{_ARTIFACT_PREFIX}*{_ARTIFACT_SUFFIX}")):
|
||||
tag = p.name[len(_ARTIFACT_PREFIX) : -len(_ARTIFACT_SUFFIX)]
|
||||
out.append((tag, p))
|
||||
return out
|
||||
|
||||
|
||||
def _resolve_checkpoint() -> str:
|
||||
env = os.environ.get("GROOT_N1_7_LIBERO_CKPT")
|
||||
if env:
|
||||
if not Path(env).exists():
|
||||
pytest.skip(f"GROOT_N1_7_LIBERO_CKPT={env} does not exist")
|
||||
return env
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
root = snapshot_download(
|
||||
"nvidia/GR00T-N1.7-LIBERO",
|
||||
local_files_only=True,
|
||||
allow_patterns=["libero_10/*"],
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
pytest.skip(f"GR00T N1.7 LIBERO checkpoint not available locally: {exc}")
|
||||
ckpt = Path(root) / "libero_10"
|
||||
if not (ckpt / "config.json").exists():
|
||||
pytest.skip(f"GR00T N1.7 LIBERO checkpoint incomplete at {ckpt}")
|
||||
return str(ckpt)
|
||||
|
||||
|
||||
def _load_artifact(path: Path):
|
||||
data = np.load(path, allow_pickle=True)
|
||||
original_action = torch.from_numpy(data["action_pred"]).float()
|
||||
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
||||
inputs = {}
|
||||
for key in data.files:
|
||||
if not key.startswith("in::"):
|
||||
continue
|
||||
name = key[4:]
|
||||
arr = data[key]
|
||||
t = torch.from_numpy(np.asarray(arr))
|
||||
declared = dtypes.get(key, "")
|
||||
if "int" in declared or "long" in declared:
|
||||
t = t.long()
|
||||
inputs[name] = t
|
||||
return original_action, inputs
|
||||
|
||||
|
||||
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||
"""Rebuild the nested model-input dict from dot-prefixed flat keys."""
|
||||
nested: dict = {}
|
||||
for dotted, value in inputs.items():
|
||||
parts = dotted.split(".")
|
||||
cur = nested
|
||||
for p in parts[:-1]:
|
||||
cur = cur.setdefault(p, {})
|
||||
cur[parts[-1]] = value
|
||||
return nested.get("inputs", nested)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lerobot_model():
|
||||
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
||||
ckpt = _resolve_checkpoint()
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model = GR00TN17.from_pretrained(
|
||||
ckpt,
|
||||
tune_llm=False,
|
||||
tune_visual=False,
|
||||
tune_projector=False,
|
||||
tune_diffusion_model=False,
|
||||
tune_vlln=False,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
# fp32 + SDPA on both sides: bf16 + differing attention kernels otherwise introduce
|
||||
# ~1e-2 numerical noise unrelated to the implementations.
|
||||
model.compute_dtype = "float32"
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
model.to(device=DEVICE, dtype=torch.float32)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
_ARTIFACTS = _discover_artifacts()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _ARTIFACTS,
|
||||
reason=(
|
||||
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
||||
"env:\n .venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py "
|
||||
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
||||
original_action, flat_inputs = _load_artifact(artifact)
|
||||
model_inputs = _unflatten(flat_inputs)
|
||||
|
||||
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
||||
torch.manual_seed(SEED)
|
||||
def cleanup_memory():
|
||||
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
|
||||
print("\nCleaning up memory...")
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(SEED)
|
||||
with torch.inference_mode():
|
||||
out = lerobot_model.get_action(model_inputs)
|
||||
lerobot_action = out["action_pred"].float().cpu()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
print("Memory cleanup complete.")
|
||||
|
||||
t = min(original_action.shape[1], lerobot_action.shape[1])
|
||||
d = min(original_action.shape[2], lerobot_action.shape[2])
|
||||
original_action = original_action[:, :t, :d]
|
||||
lerobot_action = lerobot_action[:, :t, :d]
|
||||
|
||||
diff = torch.abs(lerobot_action - original_action)
|
||||
max_diff = diff.max().item()
|
||||
print(
|
||||
f"\n[{embodiment_tag}] shapes lerobot={tuple(lerobot_action.shape)} "
|
||||
f"original={tuple(original_action.shape)} "
|
||||
f"max|diff|={max_diff:.6e} mean|diff|={diff.mean().item():.6e}"
|
||||
def set_seed_all(seed: int):
|
||||
"""Set random seed for all RNG sources to ensure reproducibility."""
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Set deterministic behavior
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
|
||||
def instantiate_lerobot_groot(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH,
|
||||
) -> tuple[
|
||||
GrootPolicy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
|
||||
if from_pretrained:
|
||||
policy = GrootPolicy.from_pretrained(
|
||||
pretrained_name_or_path=model_path,
|
||||
strict=False,
|
||||
)
|
||||
policy.config.embodiment_tag = "gr1"
|
||||
else:
|
||||
config = GrootConfig(
|
||||
base_model_path=model_path,
|
||||
n_action_steps=DUMMY_ACTION_HORIZON,
|
||||
chunk_size=DUMMY_ACTION_HORIZON,
|
||||
image_size=[IMAGE_SIZE, IMAGE_SIZE],
|
||||
device=DEVICE,
|
||||
embodiment_tag="gr1",
|
||||
)
|
||||
policy = GrootPolicy(config)
|
||||
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config=policy.config,
|
||||
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
|
||||
)
|
||||
|
||||
assert torch.allclose(lerobot_action, original_action, atol=ATOL, rtol=RTOL), (
|
||||
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
||||
return (policy, preprocessor, postprocessor)
|
||||
|
||||
|
||||
def instantiate_original_groot(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH,
|
||||
):
|
||||
"""Instantiate original Groot policy from NVIDIA's implementation."""
|
||||
from gr00t.data.transform.concat import ConcatTransform
|
||||
from gr00t.data.transform.state_action import StateActionToTensor
|
||||
from gr00t.data.transform.video import VideoToNumpy, VideoToTensor
|
||||
from gr00t.model.transforms import GR00TTransform
|
||||
|
||||
video_keys = ["video.ego_view"]
|
||||
state_keys = [
|
||||
"state"
|
||||
] # Important: Use single concatenated "state" key (not split body parts) to match preprocessing
|
||||
action_keys = [
|
||||
"action.left_arm",
|
||||
"action.right_arm",
|
||||
"action.left_hand",
|
||||
"action.right_hand",
|
||||
"action.left_leg",
|
||||
"action.right_leg",
|
||||
"action.neck",
|
||||
"action.waist",
|
||||
]
|
||||
language_keys = ["annotation.human.action.task_description"]
|
||||
|
||||
modality_config = {
|
||||
"video": ModalityConfig(
|
||||
delta_indices=[0], # Current frame only
|
||||
modality_keys=video_keys,
|
||||
),
|
||||
"state": ModalityConfig(
|
||||
delta_indices=[0],
|
||||
modality_keys=state_keys,
|
||||
),
|
||||
"action": ModalityConfig(
|
||||
delta_indices=list(range(DUMMY_ACTION_HORIZON)),
|
||||
modality_keys=action_keys,
|
||||
),
|
||||
"language": ModalityConfig(
|
||||
delta_indices=[0],
|
||||
modality_keys=language_keys,
|
||||
),
|
||||
}
|
||||
|
||||
modality_transform = ComposedModalityTransform(
|
||||
transforms=[
|
||||
VideoToTensor(apply_to=video_keys),
|
||||
VideoToNumpy(apply_to=video_keys), # Convert to numpy (GR00TTransform expects numpy arrays)
|
||||
# State is already a single concatenated key, so no StateActionToTensor needed
|
||||
# Convert action from numpy to tensor
|
||||
StateActionToTensor(apply_to=action_keys),
|
||||
# Concatenate only video and actions (state is already single key)
|
||||
ConcatTransform(
|
||||
video_concat_order=video_keys,
|
||||
state_concat_order=[], # Empty:state is already single key
|
||||
action_concat_order=action_keys,
|
||||
),
|
||||
GR00TTransform(
|
||||
max_state_dim=64,
|
||||
max_action_dim=32,
|
||||
state_horizon=1,
|
||||
action_horizon=DUMMY_ACTION_HORIZON,
|
||||
training=False,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy = Gr00tPolicy(
|
||||
model_path=model_path,
|
||||
embodiment_tag=EmbodimentTag.GR1,
|
||||
modality_config=modality_config,
|
||||
modality_transform=modality_transform,
|
||||
device=DEVICE,
|
||||
)
|
||||
|
||||
return policy, modality_config, modality_transform
|
||||
|
||||
|
||||
def create_dummy_data(device=DEVICE):
|
||||
"""Create dummy data for testing both implementations."""
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red cube and place it in the bin"
|
||||
state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)
|
||||
|
||||
batch = {
|
||||
"observation.state": state,
|
||||
"action": torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=device, # Action ground truth (for training)
|
||||
),
|
||||
"observation.images.ego_view": torch.rand(
|
||||
batch_size,
|
||||
3,
|
||||
IMAGE_SIZE,
|
||||
IMAGE_SIZE,
|
||||
dtype=torch.float32,
|
||||
device=device, # Images in [0, 1] range as expected by LeRobot
|
||||
),
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def convert_lerobot_to_original_format(batch, modality_config):
|
||||
"""Convert LeRobot batch format to original Groot format.
|
||||
|
||||
The original Groot expects observations in this format:
|
||||
{
|
||||
"video.<camera_name>": np.ndarray (T, H, W, C) or (B, T, H, W, C)
|
||||
"state.<state_component>": np.ndarray (T, D) or (B, T, D)
|
||||
"action.<action_component>": np.ndarray (T, D) or (B, T, D)
|
||||
"annotation.<annotation_type>": str or list[str]
|
||||
}
|
||||
"""
|
||||
# Original Groot expects (T, H, W, C) format for images
|
||||
# LeRobot has (B, C, H, W) format, so we need to convert
|
||||
observation = {}
|
||||
|
||||
for img_key in ["ego_view"]:
|
||||
lerobot_key = f"observation.images.{img_key}"
|
||||
if lerobot_key in batch:
|
||||
img = batch[lerobot_key]
|
||||
# Convert from (B, C, H, W) to (B, T=1, H, W, C)
|
||||
img_np = img.permute(0, 2, 3, 1).unsqueeze(1).cpu().numpy()
|
||||
# Convert [0, 1] to [0, 255] uint8 as expected by original
|
||||
img_np = (img_np * 255).astype(np.uint8)
|
||||
observation[f"video.{img_key}"] = img_np
|
||||
|
||||
# Important: The Original's GR00TTransform expects "state" as (B, T, D), not split body parts
|
||||
if "observation.state" in batch:
|
||||
state = batch["observation.state"]
|
||||
state_np = state.unsqueeze(1).cpu().numpy() # (B, 1, D)
|
||||
observation["state"] = state_np
|
||||
|
||||
if "action" in batch:
|
||||
action = batch["action"]
|
||||
action_np = action.cpu().numpy()
|
||||
|
||||
start_idx = 0
|
||||
for part_name, part_dim in GR1_BODY_PARTS.items():
|
||||
end_idx = start_idx + part_dim
|
||||
observation[f"action.{part_name}"] = action_np[:, :, start_idx:end_idx]
|
||||
start_idx = end_idx
|
||||
|
||||
if "task" in batch:
|
||||
task_list = batch["task"]
|
||||
# GR00TTransform expects language with (B, T) shape for batched data
|
||||
# Create a (B, T=1) array where each element is the string directly
|
||||
bsz = len(task_list)
|
||||
task_array = np.empty((bsz, 1), dtype=object)
|
||||
for i in range(bsz):
|
||||
task_array[i, 0] = task_list[i] # Assign string directly to each (i, 0) position
|
||||
observation["annotation.human.action.task_description"] = task_array
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def test_groot_original_vs_lerobot_pretrained():
|
||||
"""Test Groot original implementation vs LeRobot implementation with pretrained weights."""
|
||||
print("Test: Groot Original vs LeRobot with Pretrained Weights (Inference)")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
|
||||
from_pretrained=True
|
||||
)
|
||||
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
|
||||
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
print("\n[LeRobot] Running inference...")
|
||||
lerobot_policy.eval()
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
|
||||
# Important: Reset seed immediately before inference to ensure identical RNG state
|
||||
torch.manual_seed(42)
|
||||
|
||||
with torch.no_grad():
|
||||
lerobot_actions = lerobot_policy.select_action(batch_lerobot_processed)
|
||||
|
||||
print("\n[Original] Running inference...")
|
||||
original_policy.model.eval()
|
||||
observation = convert_lerobot_to_original_format(batch, modality_config)
|
||||
original_obs_transformed = modality_transform(deepcopy(observation))
|
||||
|
||||
# Important: Reset seed immediately before inference to ensure identical RNG state
|
||||
torch.manual_seed(42)
|
||||
|
||||
with torch.no_grad():
|
||||
original_model_output = original_policy.model.get_action(original_obs_transformed)
|
||||
original_actions_raw = original_model_output["action_pred"] # [2, 16, 32]
|
||||
# Take first timestep
|
||||
original_actions = original_actions_raw[:, 0, :].to(lerobot_actions.device).to(lerobot_actions.dtype)
|
||||
|
||||
print("Action Comparison:")
|
||||
diff = lerobot_actions - original_actions
|
||||
abs_diff = torch.abs(diff)
|
||||
|
||||
for batch_idx in range(lerobot_actions.shape[0]):
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Batch {batch_idx}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"{'Idx':<5} {'LeRobot':<14} {'Original':<14} {'Difference':<14}")
|
||||
print("-" * 60)
|
||||
for action_idx in range(lerobot_actions.shape[1]):
|
||||
lr_val = lerobot_actions[batch_idx, action_idx].item()
|
||||
orig_val = original_actions[batch_idx, action_idx].item()
|
||||
diff_val = abs(lr_val - orig_val)
|
||||
sign = "+" if (lr_val - orig_val) > 0 else "-"
|
||||
print(f"{action_idx:<5} {lr_val:>13.6f} {orig_val:>13.6f} {sign}{diff_val:>12.6f}")
|
||||
|
||||
max_diff = abs_diff.max().item()
|
||||
tolerance = 0.001
|
||||
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
|
||||
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6f}"
|
||||
)
|
||||
print(f"\nSuccess: Actions match within tolerance ({tolerance})!")
|
||||
|
||||
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
|
||||
del original_policy, modality_config, modality_transform
|
||||
del batch, batch_lerobot, observation
|
||||
cleanup_memory()
|
||||
|
||||
|
||||
def test_groot_forward_pass_comparison():
|
||||
"""Test forward pass comparison between LeRobot and Original Groot implementations."""
|
||||
print("Test: Forward Pass Comparison (Training Mode)")
|
||||
|
||||
set_seed_all(42)
|
||||
|
||||
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
|
||||
from_pretrained=True
|
||||
)
|
||||
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
|
||||
|
||||
batch = create_dummy_data()
|
||||
lerobot_policy.eval()
|
||||
original_policy.model.eval()
|
||||
|
||||
print("\n[LeRobot] Running forward pass...")
|
||||
batch_lerobot = deepcopy(batch)
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
|
||||
set_seed_all(42)
|
||||
with torch.no_grad():
|
||||
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
|
||||
|
||||
print(f" Loss: {lerobot_loss.item():.6f}")
|
||||
|
||||
print("\n[Original] Running forward pass...")
|
||||
observation = convert_lerobot_to_original_format(batch, modality_config)
|
||||
transformed_obs = modality_transform(observation)
|
||||
|
||||
if "action" not in transformed_obs:
|
||||
action_for_forward = batch_lerobot_processed["action"]
|
||||
action_mask_for_forward = batch_lerobot_processed["action_mask"]
|
||||
|
||||
# Match action horizon if needed
|
||||
if action_for_forward.shape[1] != original_policy.model.action_horizon:
|
||||
if action_for_forward.shape[1] < original_policy.model.action_horizon:
|
||||
pad_size = original_policy.model.action_horizon - action_for_forward.shape[1]
|
||||
last_action = action_for_forward[:, -1:, :]
|
||||
padding = last_action.repeat(1, pad_size, 1)
|
||||
action_for_forward = torch.cat([action_for_forward, padding], dim=1)
|
||||
|
||||
mask_padding = torch.zeros(
|
||||
action_mask_for_forward.shape[0],
|
||||
pad_size,
|
||||
action_mask_for_forward.shape[2],
|
||||
dtype=action_mask_for_forward.dtype,
|
||||
device=action_mask_for_forward.device,
|
||||
)
|
||||
action_mask_for_forward = torch.cat([action_mask_for_forward, mask_padding], dim=1)
|
||||
else:
|
||||
action_for_forward = action_for_forward[:, : original_policy.model.action_horizon, :]
|
||||
action_mask_for_forward = action_mask_for_forward[
|
||||
:, : original_policy.model.action_horizon, :
|
||||
]
|
||||
|
||||
transformed_obs["action"] = action_for_forward
|
||||
transformed_obs["action_mask"] = action_mask_for_forward
|
||||
|
||||
set_seed_all(42)
|
||||
with torch.no_grad():
|
||||
original_outputs = original_policy.model.forward(transformed_obs)
|
||||
|
||||
original_loss = original_outputs["loss"]
|
||||
print(f" Loss: {original_loss.item():.6f}")
|
||||
|
||||
loss_diff = abs(lerobot_loss.item() - original_loss.item())
|
||||
loss_rel_diff = loss_diff / (abs(original_loss.item()) + 1e-8) * 100
|
||||
|
||||
print("\nLoss Values:")
|
||||
print(f" LeRobot: {lerobot_loss.item():.6f}")
|
||||
print(f" Original: {original_loss.item():.6f}")
|
||||
print(f" Absolute difference: {loss_diff:.6f}")
|
||||
print(f" Relative difference: {loss_rel_diff:.2f}%")
|
||||
|
||||
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
|
||||
del original_policy, modality_config, modality_transform
|
||||
del batch, batch_lerobot, observation, transformed_obs
|
||||
cleanup_memory()
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Utilities shared by GR00T policy tests."""
|
||||
@@ -1,198 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
"""Producer (run in the ORIGINAL gr00t env): dump original GR00T N1.7 outputs + inputs.
|
||||
|
||||
The original NVIDIA ``gr00t`` package pins ``transformers==4.57.3`` (py3.10) and its
|
||||
model-config dataclasses are incompatible with the ``transformers==5.x`` that the
|
||||
LeRobot GR00T N1.7 integration requires. The two implementations therefore cannot be
|
||||
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
||||
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
||||
|
||||
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
||||
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
||||
* the random seed used right before the flow-matching sampler,
|
||||
* the raw ``action_pred`` tensor returned by ``model.get_action`` (normalized space,
|
||||
before any per-implementation action decoding).
|
||||
|
||||
Inputs are built GENERICALLY from the checkpoint metadata (no per-tag hardcoding):
|
||||
state keys + dims come from ``statistics.json``; video + language keys come from the
|
||||
processor's per-embodiment modality configs. This lets us test many embodiment tags
|
||||
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
||||
``libero_sim``.
|
||||
|
||||
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
|
||||
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
|
||||
|
||||
Usage:
|
||||
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
--ckpt <path-to-GR00T-N1.7-LIBERO/libero_10> \
|
||||
--out-dir tests/policies/groot/artifacts \
|
||||
[--tags libero_sim,oxe_droid_relative_eef_relative_joint,...] \
|
||||
[--device cuda] [--seed 42]
|
||||
|
||||
If --tags is omitted, every embodiment present in the checkpoint statistics is dumped.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
IMAGE_SIZE = 256
|
||||
BATCH_SIZE = 2
|
||||
PROMPT = "pick up the black bowl and place it on the plate"
|
||||
|
||||
|
||||
def load_statistics(ckpt: str) -> dict:
|
||||
with open(os.path.join(ckpt, "statistics.json")) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def make_observation(seed: int, video_keys, lang_key, state_spec):
|
||||
"""Build a dummy observation dict generically from the embodiment metadata."""
|
||||
rng = np.random.default_rng(seed)
|
||||
video = {
|
||||
k: rng.integers(0, 256, (BATCH_SIZE, 1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
|
||||
for k in video_keys
|
||||
}
|
||||
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
||||
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
||||
# present-but-empty so the processor's state transform finds every expected key.
|
||||
state = {
|
||||
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
|
||||
for k, dim in state_spec
|
||||
}
|
||||
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
||||
return {"video": video, "state": state, "language": language}
|
||||
|
||||
|
||||
def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_path):
|
||||
from gr00t.data.types import MessageType
|
||||
|
||||
video_keys = modality_cfg["video"].modality_keys
|
||||
lang_key = modality_cfg["language"].modality_keys[0]
|
||||
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
||||
|
||||
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
||||
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
||||
policy.modality_configs = {
|
||||
k: v for k, v in policy.processor.get_modality_configs()[tag].items() if k != "rl_info"
|
||||
}
|
||||
policy.language_key = policy.modality_configs["language"].modality_keys[0]
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
unbatched = policy._unbatch_observation(observation)
|
||||
processed = []
|
||||
for obs in unbatched:
|
||||
vla = policy._to_vla_step_data(obs)
|
||||
processed.append(policy.processor([{"type": MessageType.EPISODE_STEP.value, "content": vla}]))
|
||||
collated = policy.collate_fn(processed)
|
||||
|
||||
def to_dev(x):
|
||||
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
|
||||
return x.to(args.device, torch.float32)
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.to(args.device)
|
||||
if isinstance(x, dict):
|
||||
return {k: to_dev(v) for k, v in x.items()}
|
||||
return x
|
||||
|
||||
collated = {k: to_dev(v) for k, v in collated.items()}
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
with torch.inference_mode():
|
||||
out = fair_model.get_action(**collated)
|
||||
action_pred = out["action_pred"].float().cpu().numpy()
|
||||
|
||||
flat, meta = {}, {}
|
||||
|
||||
def flatten(prefix, obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
arr = obj.float().cpu().numpy() if torch.is_floating_point(obj) else obj.cpu().numpy()
|
||||
flat[f"in::{prefix}"] = arr
|
||||
meta[f"in::{prefix}"] = str(obj.dtype)
|
||||
elif isinstance(obj, dict):
|
||||
for k, v in obj.items():
|
||||
flatten(f"{prefix}.{k}" if prefix else k, v)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
flat[f"in::{prefix}"] = np.array(obj, dtype=object)
|
||||
else:
|
||||
flat[f"in::{prefix}"] = np.array(obj)
|
||||
|
||||
flatten("", collated)
|
||||
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez(
|
||||
out_path,
|
||||
action_pred=action_pred,
|
||||
seed=np.array(args.seed),
|
||||
device=np.array(args.device),
|
||||
embodiment_tag=np.array(tag),
|
||||
meta_keys=np.array(list(meta.keys()), dtype=object),
|
||||
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
||||
**flat,
|
||||
)
|
||||
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--ckpt", required=True)
|
||||
ap.add_argument("--out-dir", required=True, help="directory for per-tag .npz files")
|
||||
ap.add_argument("--tags", default="", help="comma-separated embodiment tags (default: all in stats)")
|
||||
ap.add_argument("--device", default="cuda")
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
args = ap.parse_args()
|
||||
|
||||
from gr00t.policy.gr00t_policy import Gr00tPolicy
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
stats = load_statistics(args.ckpt)
|
||||
requested = [t.strip() for t in args.tags.split(",") if t.strip()] or list(stats.keys())
|
||||
|
||||
# Load the policy once (for its processor/preprocessing) on any valid tag.
|
||||
bootstrap_tag = "libero_sim" if "libero_sim" in stats else requested[0]
|
||||
policy = Gr00tPolicy(embodiment_tag=bootstrap_tag, model_path=args.ckpt, device=args.device)
|
||||
all_modality = policy.processor.get_modality_configs()
|
||||
|
||||
# Load a FAIR model (SDPA + fp32) once and reuse across tags. Otherwise the
|
||||
# original checkpoint default (flash_attention_2 + bf16) introduces kernel/rounding
|
||||
# noise vs the LeRobot env (which has no flash_attn and runs SDPA).
|
||||
cfg = AutoConfig.from_pretrained(args.ckpt, trust_remote_code=True)
|
||||
cfg.use_flash_attention = False
|
||||
cfg.load_bf16 = False
|
||||
fair_model = AutoModel.from_pretrained(args.ckpt, config=cfg, trust_remote_code=True)
|
||||
fair_model.to(device=args.device, dtype=torch.float32)
|
||||
fair_model.eval()
|
||||
|
||||
out_dir = Path(args.out_dir)
|
||||
done, skipped = [], []
|
||||
for tag in requested:
|
||||
if tag not in stats or tag not in all_modality:
|
||||
print(f"[skip] {tag}: not present in checkpoint statistics/modality configs")
|
||||
skipped.append(tag)
|
||||
continue
|
||||
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
||||
try:
|
||||
dump_one_tag(
|
||||
policy, fair_model, tag, all_modality[tag], state_spec, args,
|
||||
out_dir / f"original_n1_7_{tag}.npz",
|
||||
)
|
||||
done.append(tag)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[fail] {tag}: {type(exc).__name__}: {exc}")
|
||||
skipped.append(tag)
|
||||
|
||||
print(f"\nDumped {len(done)} tags: {done}")
|
||||
if skipped:
|
||||
print(f"Skipped/failed {len(skipped)} tags: {skipped}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,273 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""Shared fixtures and helpers for VLA-JEPA tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BATCH_SIZE = 2
|
||||
ACTION_DIM = 3
|
||||
STATE_DIM = 4
|
||||
IMAGE_SIZE = 8
|
||||
ACTION_HORIZON = 4
|
||||
N_ACTION_STEPS = 2
|
||||
NUM_VIDEO_FRAMES = 3
|
||||
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
|
||||
|
||||
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def set_seed_all(seed: int) -> None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def make_config(
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> VLAJEPAConfig:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||
},
|
||||
output_features={
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
},
|
||||
device="cpu",
|
||||
chunk_size=action_horizon,
|
||||
n_action_steps=min(N_ACTION_STEPS, action_horizon),
|
||||
action_dim=action_dim,
|
||||
state_dim=state_dim,
|
||||
num_video_frames=num_video_frames,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
jepa_tubelet_size=1,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def make_train_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
action_dim: int = ACTION_DIM,
|
||||
state_dim: int = STATE_DIM,
|
||||
action_horizon: int = ACTION_HORIZON,
|
||||
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, 1, state_dim),
|
||||
ACTION: torch.randn(batch_size, action_horizon, action_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
def make_inference_batch(
|
||||
batch_size: int = BATCH_SIZE,
|
||||
state_dim: int = STATE_DIM,
|
||||
) -> dict[str, Tensor | list[str]]:
|
||||
return {
|
||||
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||
"task": ["pick up the cube"] * batch_size,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeLanguageLayer(nn.Module):
|
||||
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
|
||||
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
|
||||
return (hidden,)
|
||||
|
||||
|
||||
class _FakeLanguageModel(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self._hidden_size = hidden_size
|
||||
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
|
||||
self.layers[-1](hidden)
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
class _FakeQwenInnerModel(nn.Module):
|
||||
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.language_model = _FakeLanguageModel(hidden_size)
|
||||
|
||||
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
|
||||
return self.language_model(input_ids)
|
||||
|
||||
|
||||
class _FakeQwenBackbone(nn.Module):
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
self.config = SimpleNamespace(
|
||||
hidden_size=hidden_size,
|
||||
text_config=SimpleNamespace(hidden_size=hidden_size),
|
||||
)
|
||||
self.model = _FakeQwenInnerModel(hidden_size)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||
batch_size, seq_len = input_ids.shape
|
||||
hidden_size = self.config.hidden_size
|
||||
values = torch.arange(
|
||||
batch_size * seq_len * hidden_size,
|
||||
device=input_ids.device,
|
||||
dtype=torch.float32,
|
||||
).view(batch_size, seq_len, hidden_size)
|
||||
hidden = values / values.numel() + self.weight
|
||||
self.model(input_ids) # call through so the forward hook on layers[-1] fires
|
||||
return SimpleNamespace(hidden_states=[hidden])
|
||||
|
||||
|
||||
class _FakeQwenInterface(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
|
||||
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
||||
return action_tokens, action_token_ids, 2000
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: list[list[Image.Image]],
|
||||
instructions: list[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, Tensor]:
|
||||
batch_size = len(images)
|
||||
del images, instructions, action_prompt, embodied_prompt
|
||||
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||
token_ids = (
|
||||
[10]
|
||||
+ list(range(1000, 1000 + action_count))
|
||||
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
||||
+ [11]
|
||||
)
|
||||
return {
|
||||
"input_ids": torch.tensor(
|
||||
[token_ids] * batch_size,
|
||||
device=self.model.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||
return Image.fromarray(image)
|
||||
|
||||
|
||||
class _FakeVideoEncoder(nn.Module):
|
||||
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(1))
|
||||
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
|
||||
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
|
||||
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.weight.device
|
||||
|
||||
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
||||
batch_size, num_frames = pixel_values_videos.shape[:2]
|
||||
hidden_size = self.config.hidden_size
|
||||
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
||||
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
||||
|
||||
|
||||
class _FakeVideoProcessor:
|
||||
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
|
||||
assert return_tensors == "pt"
|
||||
if isinstance(videos, list):
|
||||
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
||||
else:
|
||||
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
||||
return {"pixel_values_videos": pixel_values}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from lerobot.policies.vla_jepa import modeling_vla_jepa
|
||||
|
||||
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoModel,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoEncoder(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
modeling_vla_jepa.AutoVideoProcessor,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeVideoProcessor(),
|
||||
)
|
||||
@@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
from conftest import (
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
QWEN_HIDDEN_SIZE,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
set_seed_all,
|
||||
) # noqa: E402
|
||||
|
||||
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
|
||||
VLAJEPAActionHead,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VLAJEPAActionHead
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4), # default test dims
|
||||
(7, 0, 16), # no proprioceptive state, production-like action space
|
||||
(6, 8, 8), # medium dims
|
||||
],
|
||||
)
|
||||
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
|
||||
assert t.shape == (200,)
|
||||
assert torch.isfinite(t).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
timesteps = torch.randint(0, 100, (2,))
|
||||
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
out_with = head._build_inputs(conditioning, actions, state, timesteps)
|
||||
out_none = head._build_inputs(conditioning, actions, None, timesteps)
|
||||
|
||||
assert out_with.ndim == 3 and out_none.ndim == 3
|
||||
if state_dim > 0:
|
||||
assert out_with.shape[1] > out_none.shape[1]
|
||||
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(2, action_horizon, action_dim)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
def test_action_head_forward_gradient_flows() -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
loss = head.forward(conditioning, actions, state)
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||
pred = head.predict_action(conditioning, state)
|
||||
assert tuple(pred.shape) == (2, action_horizon, action_dim)
|
||||
assert torch.isfinite(pred).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# action_is_pad masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_action_head_loss_fully_padded_is_zero() -> None:
|
||||
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss = head.forward(conditioning, actions, state, action_is_pad)
|
||||
assert loss.item() == 0.0
|
||||
|
||||
|
||||
def test_action_head_loss_none_matches_no_padding() -> None:
|
||||
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
|
||||
set_seed_all(42)
|
||||
config = make_config()
|
||||
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||
|
||||
set_seed_all(0)
|
||||
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
|
||||
|
||||
set_seed_all(0)
|
||||
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
|
||||
|
||||
assert torch.isclose(loss_none, loss_zeros)
|
||||
@@ -1,57 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def test_delta_indices() -> None:
|
||||
config = make_config()
|
||||
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
|
||||
assert config.action_delta_indices == list(range(ACTION_HORIZON))
|
||||
|
||||
|
||||
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
|
||||
with pytest.raises(ValueError, match="n_action_steps"):
|
||||
VLAJEPAConfig(chunk_size=4, n_action_steps=8)
|
||||
|
||||
|
||||
def test_too_few_video_frames_raises() -> None:
|
||||
with pytest.raises(ValueError, match="video_horizon"):
|
||||
VLAJEPAConfig(
|
||||
chunk_size=16,
|
||||
n_action_steps=16,
|
||||
num_video_frames=2,
|
||||
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
|
||||
)
|
||||
|
||||
|
||||
def test_validate_features_no_image_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_no_action_raises() -> None:
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||
},
|
||||
output_features={},
|
||||
)
|
||||
with pytest.raises(ValueError, match="action output feature"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_sets_action_dim_from_feature() -> None:
|
||||
config = make_config(action_dim=6, state_dim=10)
|
||||
assert config.action_dim == 6
|
||||
assert config.state_dim == 10
|
||||
@@ -1,598 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings(
|
||||
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||
)
|
||||
|
||||
from conftest import ( # noqa: E402
|
||||
ACTION_DIM,
|
||||
ACTION_HORIZON,
|
||||
BATCH_SIZE,
|
||||
EXPECTED_ACTION_CHUNK_SHAPE,
|
||||
EXPECTED_SELECT_ACTION_SHAPE,
|
||||
IMAGE_SIZE,
|
||||
N_ACTION_STEPS,
|
||||
QWEN_HIDDEN_SIZE,
|
||||
STATE_DIM,
|
||||
make_config,
|
||||
make_inference_batch,
|
||||
make_train_batch,
|
||||
set_seed_all,
|
||||
)
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402
|
||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
|
||||
from lerobot.utils.constants import ACTION # noqa: E402
|
||||
|
||||
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||
|
||||
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
|
||||
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
|
||||
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
|
||||
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core training / inference tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
|
||||
batch = make_train_batch()
|
||||
batch_before = deepcopy(batch)
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
assert logs["action_loss"] > 0
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
loss.backward()
|
||||
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
|
||||
# Batch must not be mutated.
|
||||
assert set(batch) == set(batch_before)
|
||||
for key, value in batch.items():
|
||||
if isinstance(value, Tensor):
|
||||
assert torch.equal(value, batch_before[key])
|
||||
else:
|
||||
assert value == batch_before[key]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
||||
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.train()
|
||||
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_dim,state_dim,action_horizon",
|
||||
[
|
||||
(3, 4, 4),
|
||||
(7, 0, 16),
|
||||
(6, 8, 8),
|
||||
],
|
||||
)
|
||||
def test_training_forward_various_dims(
|
||||
patch_vla_jepa_external_models: None,
|
||||
action_dim: int,
|
||||
state_dim: int,
|
||||
action_horizon: int,
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.train()
|
||||
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||
loss, _ = policy.forward(batch)
|
||||
assert torch.isfinite(loss) and loss > 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert chunk.device.type == "cpu"
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
a1 = policy.select_action(batch)
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
|
||||
def test_action_generation_various_dims(
|
||||
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
|
||||
) -> None:
|
||||
set_seed_all(42)
|
||||
config = make_config(action_dim=action_dim, state_dim=state_dim)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
policy.eval()
|
||||
batch = make_inference_batch(state_dim=state_dim)
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
assert chunk.shape[-1] == action_dim
|
||||
assert torch.isfinite(chunk).all()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
set_seed_all(123)
|
||||
actions_1 = policy.predict_action_chunk(batch)
|
||||
set_seed_all(123)
|
||||
actions_2 = policy.predict_action_chunk(batch)
|
||||
|
||||
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
for seed in [0, 42, 123]:
|
||||
set_seed_all(seed)
|
||||
chunk = policy.predict_action_chunk(make_inference_batch())
|
||||
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action queue behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
batch = make_inference_batch()
|
||||
|
||||
# First call fills the queue (n_action_steps items) and pops one.
|
||||
a1 = policy.select_action(batch)
|
||||
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
|
||||
|
||||
# Second call pops from the existing queue without calling predict_action_chunk.
|
||||
a2 = policy.select_action(batch)
|
||||
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
policy.eval()
|
||||
policy.select_action(make_inference_batch())
|
||||
assert len(policy._queues[ACTION]) > 0
|
||||
|
||||
policy.reset()
|
||||
assert len(policy._queues[ACTION]) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Format conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||
from PIL import Image
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
examples = policy._prepare_model_inputs(make_train_batch())
|
||||
|
||||
assert len(examples) == BATCH_SIZE
|
||||
for ex in examples:
|
||||
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
||||
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
||||
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
||||
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
||||
assert ex["state"].shape == (1, STATE_DIM)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
||||
assert "action" not in ex
|
||||
assert "image" in ex and "video" in ex and "lang" in ex
|
||||
|
||||
|
||||
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch["task"]
|
||||
examples = policy._prepare_model_inputs(batch)
|
||||
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
||||
|
||||
|
||||
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
batch["task"] = "open the drawer"
|
||||
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
policy = VLAJEPAPolicy(make_config())
|
||||
batch = make_inference_batch()
|
||||
del batch[OBS_STATE]
|
||||
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pretrained checkpoint
|
||||
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
|
||||
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
|
||||
return batch
|
||||
|
||||
|
||||
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
|
||||
cfg = policy.config
|
||||
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||
for key, feat in cfg.image_features.items():
|
||||
h, w = feat.shape[-2], feat.shape[-1]
|
||||
batch[key] = torch.rand(batch_size, 3, h, w)
|
||||
if cfg.robot_state_feature is not None:
|
||||
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
|
||||
return batch
|
||||
|
||||
|
||||
_CP_ROOT = "lerobot"
|
||||
|
||||
# Each tuple: (repo_id, enable_world_model)
|
||||
_HUB_VARIANTS = [
|
||||
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
|
||||
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
|
||||
]
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loads from the converted safetensors checkpoint on the Hub."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
assert policy.config.enable_world_model == enable_world_model
|
||||
assert sum(p.numel() for p in policy.parameters()) > 0
|
||||
|
||||
|
||||
@extended_test
|
||||
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
|
||||
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||
policy.train()
|
||||
|
||||
batch = _make_hub_train_batch(policy)
|
||||
loss, logs = policy.forward(batch)
|
||||
assert torch.isfinite(loss)
|
||||
assert "action_loss" in logs
|
||||
if enable_world_model:
|
||||
assert "wm_loss" in logs
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_freeze_qwen_disables_world_model() -> None:
|
||||
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
qwen_params = list(policy.model.qwen.parameters())
|
||||
assert all(not p.requires_grad for p in qwen_params)
|
||||
assert any(p.requires_grad for p in policy.model.action_model.parameters())
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_disable_world_model_loads_simpler_env() -> None:
|
||||
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
|
||||
assert not policy.config.enable_world_model
|
||||
assert policy.model.video_predictor is None
|
||||
assert policy.model.video_encoder is None
|
||||
|
||||
|
||||
@extended_test
|
||||
def test_hub_libero_inference_shape() -> None:
|
||||
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
|
||||
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
|
||||
policy.eval()
|
||||
batch = _make_hub_inference_batch(policy)
|
||||
action = policy.select_action(batch)
|
||||
assert action.shape[-1] == policy.config.action_dim
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Postprocessor unnormalization tests
|
||||
#
|
||||
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
|
||||
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
|
||||
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
return {
|
||||
ACTION: {
|
||||
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
|
||||
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
|
||||
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
|
||||
rng = np.random.default_rng(7)
|
||||
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
actions_tensor = torch.from_numpy(actions_np)
|
||||
transition = policy_action_to_transition(actions_tensor)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
|
||||
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
|
||||
from lerobot.processor import UnnormalizerProcessorStep
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
|
||||
# Deliberately out-of-range inputs
|
||||
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
|
||||
clipped = np.clip(actions_np, -1.0, 1.0)
|
||||
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||
clip_step = ClipActionsProcessorStep()
|
||||
unnorm_step = UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||
stats=dataset_stats,
|
||||
)
|
||||
|
||||
transition = policy_action_to_transition(torch.from_numpy(actions_np))
|
||||
transition = clip_step(transition)
|
||||
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||
|
||||
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_postprocessor_applied_after_predict_action_chunk(
|
||||
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
|
||||
|
||||
Verifies the split: predict_action_chunk returns normalized actions, and calling the
|
||||
postprocessor on them produces the correctly unnormalized result.
|
||||
"""
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
||||
|
||||
cfg = make_config()
|
||||
cfg.clip_normalized_actions = False
|
||||
cfg.binarize_gripper_action = False
|
||||
policy = VLAJEPAPolicy(cfg)
|
||||
policy.eval()
|
||||
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
||||
|
||||
dataset_stats = _make_dataset_stats()
|
||||
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||
|
||||
batch = make_inference_batch()
|
||||
chunk = policy.predict_action_chunk(batch)
|
||||
|
||||
# predict_action_chunk returns raw (normalized) actions
|
||||
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
|
||||
"predict_action_chunk should return raw actions without unnormalization applied."
|
||||
)
|
||||
|
||||
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
|
||||
unnormed = postprocessor(chunk)
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
|
||||
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# World-model view adjustment (padding / trimming) tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_MULTIVIEW_NUM_FRAMES = 4 # must be >= 2 * jepa_tubelet_size (=2) for world-model tests
|
||||
|
||||
|
||||
def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig:
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
config = VLAJEPAConfig(
|
||||
input_features={
|
||||
**{
|
||||
f"{OBS_IMAGES}.cam{i}": PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)
|
||||
)
|
||||
for i in range(num_views)
|
||||
},
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
|
||||
},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||
device="cpu",
|
||||
chunk_size=ACTION_HORIZON,
|
||||
n_action_steps=N_ACTION_STEPS,
|
||||
action_dim=ACTION_DIM,
|
||||
state_dim=STATE_DIM,
|
||||
num_video_frames=_MULTIVIEW_NUM_FRAMES,
|
||||
num_action_tokens_per_timestep=2,
|
||||
num_embodied_action_tokens_per_instruction=3,
|
||||
num_inference_timesteps=2,
|
||||
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||
action_model_type="DiT-test",
|
||||
action_num_layers=1,
|
||||
predictor_depth=1,
|
||||
predictor_num_heads=2,
|
||||
predictor_mlp_ratio=2.0,
|
||||
jepa_tubelet_size=jepa_tubelet_size,
|
||||
)
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict:
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
batch = {
|
||||
f"{OBS_IMAGES}.cam{i}": torch.rand(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)
|
||||
for i in range(num_views)
|
||||
}
|
||||
batch[OBS_STATE] = torch.randn(batch_size, 1, STATE_DIM)
|
||||
batch[ACTION] = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM)
|
||||
batch["task"] = ["pick up the cube"] * batch_size
|
||||
return batch
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_views",
|
||||
[
|
||||
1, # fewer views than jepa_tubelet_size → first view duplicated
|
||||
2, # exact match → unchanged
|
||||
3, # more views than jepa_tubelet_size → trimmed to first two
|
||||
],
|
||||
)
|
||||
def test_training_forward_world_model_view_adjustment(
|
||||
patch_vla_jepa_external_models: None,
|
||||
num_views: int,
|
||||
) -> None:
|
||||
"""World-model view padding/trimming must not break the training forward pass."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=num_views, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
loss, logs = policy.forward(_make_multiview_train_batch(num_views=num_views))
|
||||
assert torch.isfinite(loss)
|
||||
assert logs["wm_loss"] >= 0
|
||||
|
||||
|
||||
def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None:
|
||||
"""With one dataset view and jepa_tubelet_size=2, the view must be duplicated before encoding."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=1, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
|
||||
captured_videos: list = []
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=1))
|
||||
|
||||
# reshape is batch-major: (b0v0, b0v1, b1v0, b1v1, …)
|
||||
assert len(captured_videos) == BATCH_SIZE * 2
|
||||
for i in range(BATCH_SIZE):
|
||||
np.testing.assert_array_equal(captured_videos[2 * i], captured_videos[2 * i + 1])
|
||||
|
||||
|
||||
def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None:
|
||||
"""With three dataset views and jepa_tubelet_size=2, only the first two views reach the encoder."""
|
||||
set_seed_all(42)
|
||||
policy = VLAJEPAPolicy(_make_multiview_config(num_views=3, jepa_tubelet_size=2))
|
||||
policy.train()
|
||||
|
||||
captured_videos: list = []
|
||||
original_processor = policy.model.video_processor
|
||||
|
||||
class _CapturingProcessor:
|
||||
def __call__(self, videos: list, return_tensors: str) -> dict:
|
||||
captured_videos.extend(videos)
|
||||
return original_processor(videos=videos, return_tensors=return_tensors)
|
||||
|
||||
policy.model.video_processor = _CapturingProcessor()
|
||||
policy.forward(_make_multiview_train_batch(num_views=3))
|
||||
|
||||
# Only B*2 items must reach the encoder, not B*3.
|
||||
assert len(captured_videos) == BATCH_SIZE * 2
|
||||
@@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.world_model import (
|
||||
ActionConditionedVideoPredictor,
|
||||
)
|
||||
|
||||
_ACTION_EMBED_DIM = 8
|
||||
|
||||
|
||||
def _make_predictor(
|
||||
embed_dim: int = 8,
|
||||
action_embed_dim: int = _ACTION_EMBED_DIM,
|
||||
predictor_embed_dim: int = 24,
|
||||
num_action_tokens: int = 2,
|
||||
tokens_per_frame: int = 1,
|
||||
) -> ActionConditionedVideoPredictor:
|
||||
return ActionConditionedVideoPredictor(
|
||||
num_frames=1,
|
||||
img_size=(1, tokens_per_frame),
|
||||
patch_size=1,
|
||||
tubelet_size=1,
|
||||
embed_dim=embed_dim,
|
||||
action_embed_dim=action_embed_dim,
|
||||
predictor_embed_dim=predictor_embed_dim,
|
||||
depth=1,
|
||||
num_heads=2,
|
||||
mlp_ratio=2.0,
|
||||
num_action_tokens_per_step=num_action_tokens,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch,num_steps,tokens_per_frame,embed_dim",
|
||||
[
|
||||
(1, 2, 1, 8),
|
||||
(2, 3, 4, 8),
|
||||
(4, 5, 2, 16),
|
||||
],
|
||||
)
|
||||
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
|
||||
predictor = _make_predictor(
|
||||
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
|
||||
)
|
||||
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
|
||||
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
|
||||
out = predictor(frame_tokens, action_tokens)
|
||||
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
|
||||
assert torch.isfinite(out).all()
|
||||
|
||||
|
||||
def test_predictor_step_mismatch_raises() -> None:
|
||||
predictor = _make_predictor(tokens_per_frame=4)
|
||||
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
|
||||
with pytest.raises(RuntimeError):
|
||||
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch
|
||||
@@ -1,340 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for Robometer reward model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.factory import get_reward_model_class, make_reward_model_config
|
||||
from lerobot.rewards.robometer import RobometerConfig
|
||||
from lerobot.rewards.robometer.configuration_robometer import ROBOMETER_SPECIAL_TOKENS
|
||||
from lerobot.rewards.robometer.modeling_robometer import (
|
||||
ROBOMETER_FEATURE_PREFIX,
|
||||
convert_bins_to_continuous,
|
||||
decode_progress_outputs,
|
||||
)
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
# Length of the fake tokenizer used in `_patch_build`. The deterministic
|
||||
# resize target derived in ``RobometerConfig.__post_init__`` is therefore
|
||||
# ``_FAKE_TOKENIZER_LEN + len(ROBOMETER_SPECIAL_TOKENS)``.
|
||||
_FAKE_TOKENIZER_LEN = 100
|
||||
_EXPECTED_RESIZED_VOCAB = _FAKE_TOKENIZER_LEN + len(ROBOMETER_SPECIAL_TOKENS)
|
||||
|
||||
|
||||
class _FakeQwenConfig:
|
||||
"""Stand-in for a Qwen3-VL config (the `model.config` attribute).
|
||||
|
||||
``to_dict`` matches HF's ``PretrainedConfig.to_dict`` closely enough for
|
||||
``RobometerConfig.__post_init__`` to snapshot a meaningful ``vlm_config``
|
||||
into the saved ``config.json`` and for the reload path to round-trip
|
||||
through ``AutoConfig.for_model``.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int = 8, vocab_size: int = _FAKE_TOKENIZER_LEN) -> None:
|
||||
# `vocab_size` here is the *pre-resize* value the fake backbone advertises.
|
||||
# `__post_init__` is expected to overwrite it with `len(tokenizer) + 5`.
|
||||
self.text_config = SimpleNamespace(hidden_size=hidden_dim, vocab_size=vocab_size)
|
||||
self._hidden_dim = hidden_dim
|
||||
self._vocab_size = vocab_size
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"model_type": "fake_qwen",
|
||||
"text_config": {
|
||||
"hidden_size": self._hidden_dim,
|
||||
"vocab_size": self._vocab_size,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class _FakeEmbeddings(torch.nn.Module):
|
||||
def __init__(self, num_embeddings: int = _FAKE_TOKENIZER_LEN) -> None:
|
||||
super().__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
|
||||
|
||||
class _FakeBaseModel(torch.nn.Module):
|
||||
"""Stand-in for the Qwen3-VL backbone during tests.
|
||||
|
||||
Provides the minimum surface `RobometerRewardModel.__init__` and
|
||||
`_compute_rbm_logits` rely on: a `parameters()` iterator (for dtype +
|
||||
device), a `config.text_config.hidden_size`, a `config.to_dict()` so
|
||||
`_save_pretrained` can snapshot `vlm_config`,
|
||||
`get_input_embeddings()` / `resize_token_embeddings()` so the fresh-init
|
||||
embed resize is a no-op, and a forward that returns a `SimpleNamespace`
|
||||
with a `hidden_states` tuple.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int = 8) -> None:
|
||||
super().__init__()
|
||||
self._param = torch.nn.Parameter(torch.zeros(1))
|
||||
self.hidden_dim = hidden_dim
|
||||
self.config = _FakeQwenConfig(hidden_dim)
|
||||
self._embeddings = _FakeEmbeddings()
|
||||
|
||||
def get_input_embeddings(self) -> _FakeEmbeddings:
|
||||
return self._embeddings
|
||||
|
||||
def resize_token_embeddings(self, new_size: int) -> None:
|
||||
self._embeddings.num_embeddings = new_size
|
||||
|
||||
def forward(self, **kwargs): # noqa: ARG002 - intentional kwargs sink
|
||||
input_ids = kwargs["input_ids"]
|
||||
return SimpleNamespace(
|
||||
hidden_states=(torch.zeros(input_ids.shape[0], input_ids.shape[1], self.hidden_dim),),
|
||||
last_hidden_state=torch.zeros(input_ids.shape[0], input_ids.shape[1], self.hidden_dim),
|
||||
)
|
||||
|
||||
|
||||
class _FakeTokenizer:
|
||||
"""Minimal stand-in for an HF tokenizer.
|
||||
|
||||
``RobometerConfig.__post_init__`` uses ``len(tokenizer)`` to compute the
|
||||
deterministic resize target ``len(tokenizer) + len(ROBOMETER_SPECIAL_TOKENS)``,
|
||||
so a working ``__len__`` is all we need.
|
||||
"""
|
||||
|
||||
def __init__(self, length: int = _FAKE_TOKENIZER_LEN) -> None:
|
||||
self._length = length
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
|
||||
|
||||
def _patch_build(monkeypatch) -> None:
|
||||
"""Stub out the HF AutoX calls so Robometer construction stays cheap in tests.
|
||||
|
||||
Covers (EO-1 style — no model-side override hooks):
|
||||
* ``AutoConfig.from_pretrained`` (config side) — used by
|
||||
``RobometerConfig.__post_init__`` to snapshot the backbone config.
|
||||
* ``AutoTokenizer.from_pretrained`` (config side) — used by
|
||||
``__post_init__`` to compute ``len(tokenizer) + 5``.
|
||||
* ``AutoConfig.for_model`` — used by
|
||||
``RobometerConfig.vlm_backbone_config`` when rebuilding for ``from_config``.
|
||||
* ``AutoModelForImageTextToText.from_pretrained`` — fresh-training path
|
||||
(``pretrained_path is None``).
|
||||
* ``AutoModelForImageTextToText.from_config`` — checkpoint-reload path
|
||||
(``pretrained_path`` is set).
|
||||
"""
|
||||
from lerobot.rewards.robometer import configuration_robometer, modeling_robometer
|
||||
|
||||
monkeypatch.setattr(
|
||||
modeling_robometer.AutoModelForImageTextToText,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeBaseModel(hidden_dim=8),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
modeling_robometer.AutoModelForImageTextToText,
|
||||
"from_config",
|
||||
lambda *args, **kwargs: _FakeBaseModel(hidden_dim=8),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
configuration_robometer.AutoConfig,
|
||||
"for_model",
|
||||
lambda *args, **kwargs: _FakeQwenConfig(hidden_dim=8),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
configuration_robometer.AutoConfig,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeQwenConfig(hidden_dim=8),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
configuration_robometer.AutoTokenizer,
|
||||
"from_pretrained",
|
||||
lambda *args, **kwargs: _FakeTokenizer(length=_FAKE_TOKENIZER_LEN),
|
||||
)
|
||||
|
||||
|
||||
def _make_batch(features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Build a `compute_reward`-ready batch using Robometer's namespaced keys."""
|
||||
return {f"{ROBOMETER_FEATURE_PREFIX}{key}": value for key, value in features.items()}
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_robometer_config_registered(monkeypatch):
|
||||
_patch_build(monkeypatch)
|
||||
assert "robometer" in RewardModelConfig.get_known_choices()
|
||||
assert RewardModelConfig.get_choice_class("robometer") is RobometerConfig
|
||||
assert isinstance(make_reward_model_config("robometer", device="cpu"), RobometerConfig)
|
||||
|
||||
|
||||
def test_robometer_factory_returns_in_tree_class():
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
assert get_reward_model_class("robometer") is RobometerRewardModel
|
||||
|
||||
|
||||
def test_convert_bins_to_continuous_returns_expected_values():
|
||||
# Two frames: first peaks at bin 0 (center 0.0), second peaks at bin 9 (center 1.0).
|
||||
bin_logits = torch.full((2, 10), -10.0)
|
||||
bin_logits[0, 0] = 10.0
|
||||
bin_logits[1, -1] = 10.0
|
||||
values = convert_bins_to_continuous(bin_logits)
|
||||
assert values.shape == (2,)
|
||||
assert torch.allclose(values, torch.tensor([0.0, 1.0]), atol=1e-3)
|
||||
|
||||
|
||||
def test_decode_progress_outputs_returns_last_frame_values():
|
||||
progress = torch.tensor([[0.1, 0.9], [0.4, 0.6]])
|
||||
success_logits = torch.tensor([[0.0, 5.0], [0.0, -5.0]])
|
||||
|
||||
outputs = decode_progress_outputs(progress, success_logits, is_discrete_mode=False)
|
||||
|
||||
assert outputs["progress_pred"] == [pytest.approx([0.1, 0.9]), pytest.approx([0.4, 0.6])]
|
||||
assert outputs["success_probs"][0][-1] == pytest.approx(torch.sigmoid(torch.tensor(5.0)).item(), abs=1e-3)
|
||||
assert outputs["success_probs"][1][-1] == pytest.approx(
|
||||
torch.sigmoid(torch.tensor(-5.0)).item(), abs=1e-3
|
||||
)
|
||||
|
||||
|
||||
def test_decode_progress_outputs_discrete_mode_softmaxes_over_bins():
|
||||
# 2 frames, peaks at bin 0 and bin 9 → continuous predictions 0.0 and 1.0
|
||||
bin_logits = torch.full((1, 2, 10), -10.0)
|
||||
bin_logits[0, 0, 0] = 10.0
|
||||
bin_logits[0, 1, -1] = 10.0
|
||||
|
||||
outputs = decode_progress_outputs(bin_logits, success_logits=None, is_discrete_mode=True)
|
||||
|
||||
assert outputs["success_probs"] == []
|
||||
assert outputs["progress_pred"][0] == pytest.approx([0.0, 1.0], abs=1e-3)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_robometer_post_init_overwrites_vocab_size_with_tokenizer_length(monkeypatch):
|
||||
"""``RobometerConfig.__post_init__`` must overwrite the backbone's stale
|
||||
``text_config.vocab_size`` (which on the real Qwen3-VL config is the
|
||||
padded embedding size, ``151,936``) with ``len(tokenizer) + 5``. This is
|
||||
the contract that makes the published ``Robometer-4B`` checkpoint load
|
||||
byte-equivalently."""
|
||||
_patch_build(monkeypatch)
|
||||
|
||||
cfg = RobometerConfig(device="cpu", progress_loss_type="l2")
|
||||
|
||||
assert cfg.vlm_config["text_config"]["vocab_size"] == _EXPECTED_RESIZED_VOCAB
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_robometer_compute_reward_reads_pre_encoded_inputs(monkeypatch):
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
progress = torch.tensor([[0.1, 0.9], [0.4, 0.6]])
|
||||
success_logits = torch.tensor([[0.0, 5.0], [0.0, -5.0]])
|
||||
_patch_build(monkeypatch)
|
||||
|
||||
cfg = RobometerConfig(device="cpu", reward_output="progress", progress_loss_type="l2")
|
||||
model = RobometerRewardModel(cfg)
|
||||
# Bypass the Qwen3-VL forward + head extraction with deterministic logits.
|
||||
monkeypatch.setattr(model, "_compute_rbm_logits", lambda _inputs: (progress, success_logits))
|
||||
|
||||
batch = _make_batch({"input_ids": torch.zeros(2, 2, dtype=torch.long)})
|
||||
rewards = model.compute_reward(batch)
|
||||
|
||||
assert torch.allclose(rewards, torch.tensor([0.9, 0.6]))
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_robometer_compute_reward_can_return_binary_success(monkeypatch):
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
progress = torch.tensor([[0.1, 0.9], [0.4, 0.6]])
|
||||
success_logits = torch.tensor([[0.0, 5.0], [0.0, -5.0]]) # sigmoid(5) > 0.5; sigmoid(-5) < 0.5
|
||||
_patch_build(monkeypatch)
|
||||
|
||||
cfg = RobometerConfig(
|
||||
device="cpu",
|
||||
reward_output="success",
|
||||
success_threshold=0.5,
|
||||
progress_loss_type="l2",
|
||||
)
|
||||
model = RobometerRewardModel(cfg)
|
||||
monkeypatch.setattr(model, "_compute_rbm_logits", lambda _inputs: (progress, success_logits))
|
||||
|
||||
batch = _make_batch({"input_ids": torch.zeros(2, 2, dtype=torch.long)})
|
||||
rewards = model.compute_reward(batch)
|
||||
|
||||
assert torch.equal(rewards, torch.tensor([1.0, 0.0]))
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_robometer_compute_reward_errors_when_inputs_missing(monkeypatch):
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
|
||||
cfg = RobometerConfig(device="cpu", progress_loss_type="l2")
|
||||
model = RobometerRewardModel(cfg)
|
||||
|
||||
with pytest.raises(KeyError, match=r"observation\.robometer\.input_ids"):
|
||||
model.compute_reward({})
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_robometer_save_pretrained_roundtrips(monkeypatch, tmp_path):
|
||||
"""Saving and reloading a Robometer model in LeRobot HF format must produce
|
||||
a single ``model.safetensors`` + ``config.json`` (no Hydra ``config.yaml``),
|
||||
must round-trip user-tunable config fields, and must persist all three
|
||||
prediction heads (``progress_head``, ``success_head``, ``preference_head``)
|
||||
so the published ``Robometer-4B`` checkpoint loads byte-equivalently.
|
||||
"""
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.rewards.robometer.modeling_robometer import RobometerRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = RobometerConfig(
|
||||
device="cpu",
|
||||
pretrained_path="robometer/Robometer-4B",
|
||||
# Knobs the user might tweak — must survive the round-trip.
|
||||
image_key="observation.images.cam_top",
|
||||
task_key="task",
|
||||
reward_output="success",
|
||||
success_threshold=0.7,
|
||||
progress_loss_type="l2",
|
||||
)
|
||||
model = RobometerRewardModel(cfg)
|
||||
model.save_pretrained(str(tmp_path))
|
||||
|
||||
# Exactly the files LeRobot's HubMixin promises.
|
||||
assert (tmp_path / CONFIG_NAME).exists()
|
||||
assert (tmp_path / SAFETENSORS_SINGLE_FILE).exists()
|
||||
assert not (tmp_path / "config.yaml").exists() # we want HF-style, not Hydra
|
||||
|
||||
# All three heads must be present in the saved safetensors. The preference
|
||||
# head is unused at inference but the published checkpoint expects its
|
||||
# rows — losing it would silently break weight loading.
|
||||
state = load_file(str(tmp_path / SAFETENSORS_SINGLE_FILE))
|
||||
assert any(k.startswith("progress_head.") for k in state), "progress_head weights missing"
|
||||
assert any(k.startswith("success_head.") for k in state), "success_head weights missing"
|
||||
assert any(k.startswith("preference_head.") for k in state), "preference_head weights missing"
|
||||
|
||||
# Reload from the local directory: no Hub fetch, no YAML overlay. The
|
||||
# base class drives subclass dispatch via the `type` field in config.json.
|
||||
reloaded_cfg = RewardModelConfig.from_pretrained(str(tmp_path))
|
||||
assert isinstance(reloaded_cfg, RobometerConfig)
|
||||
reloaded_cfg.pretrained_path = str(tmp_path) # mimic lerobot-train's `validate()`
|
||||
reloaded = RobometerRewardModel.from_pretrained(str(tmp_path), config=reloaded_cfg)
|
||||
|
||||
assert reloaded.config.image_key == "observation.images.cam_top"
|
||||
assert reloaded.config.task_key == "task"
|
||||
assert reloaded.config.reward_output == "success"
|
||||
assert reloaded.config.success_threshold == 0.7
|
||||
assert reloaded.config.progress_loss_type == "l2" # came back from config.json
|
||||
@@ -1,354 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for Robometer's pre-processing helpers and encoder step.
|
||||
|
||||
Covers the pure helpers (``_video_to_numpy`` and ``_expand_tasks``) directly,
|
||||
and exercises :class:`RobometerEncoderProcessorStep` with a stubbed
|
||||
``AutoProcessor`` so we don't need to download Qwen-VL just to test the
|
||||
dataclass plumbing (``transform_features`` / ``get_config``).
|
||||
|
||||
The full ``__call__`` path that runs ``process_vision_info`` + the Qwen
|
||||
processor is intentionally *not* covered here — it is essentially HF glue
|
||||
that's exercised by the integration / parity scripts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.rewards.robometer.processor_robometer import (
|
||||
PROGRESS_PROMPT,
|
||||
_expand_tasks,
|
||||
_frames_to_pil,
|
||||
_video_to_numpy,
|
||||
)
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
def _skip_if_robometer_extras_missing(func):
|
||||
"""Apply both optional-dependency guards in one shot.
|
||||
|
||||
``RobometerEncoderProcessorStep.__post_init__`` calls
|
||||
``require_package("transformers", ...)`` *and*
|
||||
``require_package("qwen-vl-utils", ...)``, so both need to be present
|
||||
before we can instantiate the step.
|
||||
"""
|
||||
func = skip_if_package_missing("qwen-vl-utils", import_name="qwen_vl_utils")(func)
|
||||
func = skip_if_package_missing("transformers")(func)
|
||||
return func
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _video_to_numpy — pure tensor → uint8 (T, H, W, C) conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_video_to_numpy_chw_float_is_converted_to_thwc_uint8():
|
||||
video = torch.rand(4, 3, 8, 8) # (T, C, H, W) floats in [0, 1]
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert array.shape == (4, 8, 8, 3)
|
||||
assert array.dtype == np.uint8
|
||||
assert array.min() >= 0 and array.max() <= 255
|
||||
|
||||
|
||||
def test_video_to_numpy_already_thwc_uint8_passes_through():
|
||||
video = torch.randint(0, 256, (3, 8, 8, 3), dtype=torch.uint8) # (T, H, W, C)
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert array.shape == (3, 8, 8, 3)
|
||||
assert array.dtype == np.uint8
|
||||
|
||||
|
||||
def test_video_to_numpy_max_frames_tail_crops_recent_frames():
|
||||
"""``max_frames`` should keep the **last** K frames (most recent)."""
|
||||
video = torch.zeros(10, 3, 4, 4)
|
||||
for t in range(10):
|
||||
video[t] = t / 9.0 # marker: 0 at t=0, ≈1 at t=9
|
||||
|
||||
array = _video_to_numpy(video, max_frames=3)
|
||||
|
||||
assert array.shape == (3, 4, 4, 3)
|
||||
# The first kept frame is t=7 → marker ≈ 7/9 → uint8 ≈ 198
|
||||
assert int(array[0, 0, 0, 0]) == int(round(7 / 9 * 255))
|
||||
# The last kept frame is t=9 → marker = 1.0 → uint8 = 255
|
||||
assert int(array[-1, 0, 0, 0]) == 255
|
||||
|
||||
|
||||
def test_video_to_numpy_rejects_3d_input():
|
||||
with pytest.raises(ValueError, match="Expected channel dim"):
|
||||
_video_to_numpy(torch.zeros(4, 8, 8), max_frames=None)
|
||||
|
||||
|
||||
def test_video_to_numpy_floats_above_one_pass_through_without_rescaling():
|
||||
"""If ``array.max() > 1`` the helper assumes the tensor is already in the
|
||||
[0, 255] range (uint8-as-float), so values pass through unchanged."""
|
||||
video = torch.full((1, 3, 2, 2), 5.0)
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert array.shape == (1, 2, 2, 3)
|
||||
assert int(array.max()) == 5
|
||||
|
||||
|
||||
def test_video_to_numpy_clips_very_large_floats_to_uint8_max():
|
||||
"""Out-of-uint8-range floats are clipped at 255 before the cast."""
|
||||
video = torch.full((1, 3, 2, 2), 300.0)
|
||||
array = _video_to_numpy(video, max_frames=None)
|
||||
|
||||
assert int(array.max()) == 255
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _expand_tasks — string / list / tuple broadcasting to batch size
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_expand_tasks_string_is_broadcast_to_batch_size():
|
||||
assert _expand_tasks("pick up", batch_size=3, default=None) == ["pick up", "pick up", "pick up"]
|
||||
|
||||
|
||||
def test_expand_tasks_list_of_matching_size_passes_through():
|
||||
assert _expand_tasks(["a", "b", "c"], batch_size=3, default=None) == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_expand_tasks_tuple_is_normalised_to_list():
|
||||
assert _expand_tasks(("a", "b"), batch_size=2, default=None) == ["a", "b"]
|
||||
|
||||
|
||||
def test_expand_tasks_single_element_list_is_broadcast():
|
||||
assert _expand_tasks(["only one"], batch_size=3, default=None) == ["only one"] * 3
|
||||
|
||||
|
||||
def test_expand_tasks_size_mismatch_raises():
|
||||
with pytest.raises(ValueError, match="Expected 3 tasks"):
|
||||
_expand_tasks(["a", "b"], batch_size=3, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_missing_uses_default():
|
||||
assert _expand_tasks(None, batch_size=2, default="fallback") == ["fallback", "fallback"]
|
||||
|
||||
|
||||
def test_expand_tasks_missing_without_default_raises():
|
||||
with pytest.raises(KeyError, match="task description"):
|
||||
_expand_tasks(None, batch_size=1, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_wrong_type_raises():
|
||||
with pytest.raises(TypeError, match="must be a string or list"):
|
||||
_expand_tasks(42, batch_size=1, default=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _frames_to_pil — uint8 (T, H, W, C) → list[PIL.Image]
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_frames_to_pil_returns_one_image_per_frame():
|
||||
frames = np.zeros((4, 8, 8, 3), dtype=np.uint8)
|
||||
images = _frames_to_pil(frames)
|
||||
|
||||
assert len(images) == 4
|
||||
assert all(img.size == (8, 8) for img in images)
|
||||
|
||||
|
||||
def test_frames_to_pil_casts_floats_to_uint8():
|
||||
frames = np.full((2, 4, 4, 3), 200.0, dtype=np.float32)
|
||||
images = _frames_to_pil(frames)
|
||||
|
||||
assert len(images) == 2
|
||||
# PIL converted from clipped uint8 - sanity check pixel values come through.
|
||||
assert np.asarray(images[0]).dtype == np.uint8
|
||||
|
||||
|
||||
def test_frames_to_pil_rejects_non_4d_input():
|
||||
with pytest.raises(ValueError, match=r"\(T,H,W,C\)"):
|
||||
_frames_to_pil(np.zeros((4, 8, 8), dtype=np.uint8))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder step plumbing — exercise dataclass surface with a stubbed AutoProcessor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeTokenizer:
|
||||
"""Tokenizer surface the encoder step touches in ``__post_init__``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pad_token: str | None = None
|
||||
self.eos_token = "<|endoftext|>"
|
||||
self._vocab: dict[str, int] = {"<|endoftext|>": 0}
|
||||
self.added: list[str] = []
|
||||
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
return self._vocab
|
||||
|
||||
def add_special_tokens(self, payload: dict[str, Any]) -> int:
|
||||
for token in payload.get("additional_special_tokens", []):
|
||||
if token not in self._vocab:
|
||||
self._vocab[token] = len(self._vocab)
|
||||
self.added.append(token)
|
||||
return len(self.added)
|
||||
|
||||
|
||||
class _FakeAutoProcessor:
|
||||
"""Stand-in returned by ``AutoProcessor.from_pretrained`` during tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = _FakeTokenizer()
|
||||
self.image_processor = None
|
||||
self.video_processor = None
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
|
||||
return cls()
|
||||
|
||||
|
||||
def _build_step(monkeypatch, **overrides):
|
||||
from lerobot.rewards.robometer import processor_robometer
|
||||
|
||||
monkeypatch.setattr(processor_robometer, "AutoProcessor", _FakeAutoProcessor)
|
||||
|
||||
return processor_robometer.RobometerEncoderProcessorStep(**overrides)
|
||||
|
||||
|
||||
@_skip_if_robometer_extras_missing
|
||||
def test_encoder_step_registers_special_tokens_on_tokenizer(monkeypatch):
|
||||
"""``__post_init__`` must register Robometer's five special tokens on the
|
||||
tokenizer that ships with the chosen Qwen-VL checkpoint."""
|
||||
from lerobot.rewards.robometer.configuration_robometer import ROBOMETER_SPECIAL_TOKENS
|
||||
|
||||
step = _build_step(monkeypatch)
|
||||
|
||||
vocab = step._processor.tokenizer.get_vocab()
|
||||
for token in ROBOMETER_SPECIAL_TOKENS:
|
||||
assert token in vocab, f"{token} not registered on the tokenizer"
|
||||
|
||||
|
||||
@_skip_if_robometer_extras_missing
|
||||
def test_encoder_step_sets_pad_token_to_eos_when_missing(monkeypatch):
|
||||
"""Qwen tokenizers ship without a pad token; the step must reuse EOS so
|
||||
batched processing doesn't crash on padding."""
|
||||
step = _build_step(monkeypatch)
|
||||
|
||||
assert step._processor.tokenizer.pad_token == "<|endoftext|>"
|
||||
|
||||
|
||||
@_skip_if_robometer_extras_missing
|
||||
def test_encoder_step_get_config_roundtrips_user_fields(monkeypatch):
|
||||
"""``get_config`` must serialise every user-tunable field — these are what
|
||||
the processor pipeline saves under ``preprocessor_config.json``."""
|
||||
step = _build_step(
|
||||
monkeypatch,
|
||||
base_model_id="Qwen/Qwen3-VL-4B-Instruct",
|
||||
image_key="observation.images.cam_top",
|
||||
task_key="task",
|
||||
default_task="do the thing",
|
||||
max_frames=12,
|
||||
use_multi_image=True,
|
||||
use_per_frame_progress_token=True,
|
||||
max_length=2048,
|
||||
)
|
||||
|
||||
cfg = step.get_config()
|
||||
assert cfg == {
|
||||
"base_model_id": "Qwen/Qwen3-VL-4B-Instruct",
|
||||
"image_key": "observation.images.cam_top",
|
||||
"task_key": "task",
|
||||
"default_task": "do the thing",
|
||||
"max_frames": 12,
|
||||
"use_multi_image": True,
|
||||
"use_per_frame_progress_token": True,
|
||||
"max_length": 2048,
|
||||
}
|
||||
|
||||
|
||||
@_skip_if_robometer_extras_missing
|
||||
def test_encoder_step_transform_features_is_identity(monkeypatch):
|
||||
"""The encoder step writes Qwen tensors into ``observation`` at call time,
|
||||
but it does **not** advertise new typed features at pipeline-build time —
|
||||
the downstream model consumes them via the ``ROBOMETER_FEATURE_PREFIX``
|
||||
namespace, not via the typed feature map.
|
||||
"""
|
||||
step = _build_step(monkeypatch)
|
||||
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.images.top": PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL),
|
||||
}
|
||||
}
|
||||
assert step.transform_features(features) == features
|
||||
|
||||
|
||||
@_skip_if_robometer_extras_missing
|
||||
def test_encoder_step_build_conversation_inserts_prog_token_per_frame(monkeypatch):
|
||||
"""In multi-image mode with per-frame progress tokens, the conversation
|
||||
must alternate ``image`` and ``<|prog_token|>`` text entries, one pair
|
||||
per frame, after the task prompt."""
|
||||
step = _build_step(
|
||||
monkeypatch,
|
||||
use_multi_image=True,
|
||||
use_per_frame_progress_token=True,
|
||||
)
|
||||
|
||||
frames = np.zeros((3, 8, 8, 3), dtype=np.uint8)
|
||||
conversation = step._build_conversation(frames, task="pick up the cube")
|
||||
|
||||
assert len(conversation) == 1 and conversation[0]["role"] == "user"
|
||||
content = conversation[0]["content"]
|
||||
|
||||
# First entry is the task prompt.
|
||||
assert content[0] == {"type": "text", "text": PROGRESS_PROMPT.format(task="pick up the cube")}
|
||||
|
||||
# Then 3 (image, <|prog_token|>) pairs.
|
||||
expected_tail = [
|
||||
item
|
||||
for _ in range(3)
|
||||
for item in (
|
||||
{"type": "image"}, # value asserted below
|
||||
{"type": "text", "text": "<|prog_token|>"},
|
||||
)
|
||||
]
|
||||
assert len(content) == 1 + len(expected_tail)
|
||||
for got, exp in zip(content[1:], expected_tail, strict=True):
|
||||
assert got["type"] == exp["type"]
|
||||
if exp["type"] == "text":
|
||||
assert got["text"] == exp["text"]
|
||||
|
||||
|
||||
@_skip_if_robometer_extras_missing
|
||||
def test_encoder_step_build_conversation_video_mode_uses_single_video_entry(monkeypatch):
|
||||
"""When ``use_multi_image=False``, frames are bundled into a single
|
||||
``video`` content entry instead of individual ``image`` entries."""
|
||||
step = _build_step(
|
||||
monkeypatch,
|
||||
use_multi_image=False,
|
||||
use_per_frame_progress_token=False,
|
||||
)
|
||||
|
||||
frames = np.zeros((4, 8, 8, 3), dtype=np.uint8)
|
||||
conversation = step._build_conversation(frames, task="pour the water")
|
||||
|
||||
content = conversation[0]["content"]
|
||||
# Exactly two entries: the prompt and one video entry.
|
||||
assert len(content) == 2
|
||||
assert content[0]["type"] == "text"
|
||||
assert content[1]["type"] == "video"
|
||||
# The video entry carries all four frames.
|
||||
assert len(content[1]["video"]) == 4
|
||||
@@ -2989,11 +2989,6 @@ rebot = [
|
||||
{ name = "motorbridge" },
|
||||
{ name = "motorbridge-smart-servo" },
|
||||
]
|
||||
robometer = [
|
||||
{ name = "peft" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
robstride = [
|
||||
{ name = "python-can" },
|
||||
]
|
||||
@@ -3052,11 +3047,6 @@ video-benchmark = [
|
||||
viz = [
|
||||
{ name = "rerun-sdk" },
|
||||
]
|
||||
vla-jepa = [
|
||||
{ name = "diffusers" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
wallx = [
|
||||
{ name = "peft" },
|
||||
{ name = "qwen-vl-utils" },
|
||||
@@ -3125,7 +3115,6 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||
@@ -3157,7 +3146,6 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'molmoact2'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'robometer'" },
|
||||
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["phone"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["pi"], marker = "extra == 'all'" },
|
||||
@@ -3175,13 +3163,10 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'lekiwi'" },
|
||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'robometer'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["robometer"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["robstride"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["sarm"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'aloha'" },
|
||||
@@ -3203,18 +3188,15 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'peft'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'robometer'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'topreward'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
|
||||
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
|
||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
|
||||
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
|
||||
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
|
||||
@@ -3276,7 +3258,7 @@ requires-dist = [
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
|
||||
Reference in New Issue
Block a user