Compare commits

10 Commits

Author SHA1 Message Date
Maxime Ellerbach ecf342d481 small fix for the preprocessor and padded images 2026-06-16 11:27:51 +00:00
Maxime Ellerbach 1e762d5240 linting 2026-06-15 12:11:39 +00:00
Maxime Ellerbach 35c3302f4d re-parenting of some layers to enable proper zero-3 FSDP 2026-06-15 12:11:27 +00:00
Maxime Ellerbach a323ea67b6 preparing for training, adding some temporary debug code as well to visualize model output 2026-06-12 15:25:28 +00:00
Maxime Ellerbach 7c063c3fbc changing reproducible results 2026-06-12 08:57:11 +00:00
Maxime Ellerbach 9cf12c941d big refactor to use models from diffusers and transformers 2026-06-12 08:56:58 +00:00
ZibinDong 4039da81c6 Add FastWAM policy review updates 2026-06-09 13:37:59 +00:00
ZibinDong b3a28a49f6 Add FastWAM policy 2026-06-09 13:37:59 +00:00
Adil Zouitine 49755a3d9e feat(processor): Add in-memory processor pipeline serialization (#3732)
* feat(processor): add in-memory pipeline serialization

Expose processor pipeline config and tensor state without requiring temporary files, so processors can be transported, compared, or hashed directly in memory.

* feat(processor): enhance DataProcessorPipeline with registry support

- Added a new RegisteredLazyTensorStateStep for registry-based serialization tests.
- Improved state filename handling in _get_state_filename method.
- Refactored validation logic in _validate_loaded_config to simplify parameter types.
- Updated tests to verify registry step functionality and ensure correct state loading.

* refactor(processor): update state handling in DataProcessorPipeline

- Introduced a new static method _get_state_key to derive in-memory state keys from serialized filenames.
- Updated state_dict and load_state_dict methods to use suffixless state keys instead of filenames.
- Adjusted related tests to reflect changes in state key handling, ensuring consistency in state management

* fix(processor): update loaded_config argument description in DataProcessorPipeline

- Clarified the documentation for the loaded_config parameter to indicate that it may be a non-dictionary value, enhancing understanding for future developers.
2026-06-08 11:27:24 +02:00
Maxime Ellerbach 09808183ca feat(rollout): adding episodic strategy (#3717)
* feat(rollout): adding legacy strategy

* adding legacy to existing tests

* updating docs and docstring

* changing misleading docstring

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding extra guard like dagged with try except finally

* Potential fix for pull request finding

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* adding reset to initial position

* moving smooth teleop handover to control_utils and adding this behavior to legacy strategy

* reducing duration of the handover

* * renaming to episodic
* changing semantics of the docstring
* fixing leader - follower handover disable torque
* adding optional config to disable handover

* wiring the smooth_leader_follower_handover config

* renaming config smooth_leader_to_follower_handover

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-06-06 00:32:38 +02:00
60 changed files with 10107 additions and 6203 deletions
+1 -1
@@ -105,7 +105,7 @@ lerobot-train \
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
+3 -1
@@ -67,8 +67,10 @@
title: VLA-JEPA
- local: eo1
title: EO-1
- local: fastwam
title: FastWAM
- local: groot
title: NVIDIA GR00T
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
+1 -1
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
- [SmolVLA](./smolvla)
- [Pi0.5](./pi05)
- [GR00T N1.7](./groot)
- [GR00T N1.5](./groot)
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
+163
@@ -0,0 +1,163 @@
# FastWAM
FastWAM is a World Action Model policy for robot control. The LeRobot integration exposes FastWAM through the standard policy API so it can be configured with `policy.type=fastwam`, trained with `lerobot-train`, and loaded through the LeRobot pretrained policy interface.
## Model Overview
FastWAM keeps video modeling during training, but uses direct action prediction at inference time instead of iteratively generating future observations. This LeRobot policy wraps the FastWAM action model, adapts LeRobot batches to FastWAM training samples, and provides the standard processor pipeline for normalization and action postprocessing.
The implementation initializes the visual world-model components from `Wan-AI/Wan2.2-TI2V-5B` by default and predicts action chunks with shape `[batch, action_horizon, action_dim]`.
### What the LeRobot Integration Covers
- Standard `policy.type=fastwam` configuration through LeRobot
- Image, state, action, and language-task batch adaptation
- Action chunk inference through `select_action` and `predict_action_chunk`
- Checkpoint save/load through the LeRobot policy APIs
- Configurable LIBERO gripper action postprocessing
## Installation Requirements
Install LeRobot from source, then install FastWAM dependencies:
```bash
pip install -e ".[fastwam]"
```
This installs the FastWAM policy extra from `pyproject.toml`: `transformers`,
`diffusers`, `ftfy`, and `regex`, plus LeRobot's base dependencies.
For LIBERO evaluation, install the benchmark dependencies too:
```bash
pip install -e ".[fastwam,libero]"
```
This installs both extras. In addition to the FastWAM dependencies above, the
`libero` extra installs LeRobot dataset dependencies, `hf-libero` on Linux, and
`scipy`.
FastWAM uses the Wan2.2 TI2V backbone. The default model id is:
```python
policy.model_id=Wan-AI/Wan2.2-TI2V-5B
```
## Data Requirements
FastWAM expects a LeRobot dataset with:
- one or more visual observations whose widths concatenate to `policy.image_size[1]`
- `observation.state` when `policy.proprio_dim` is not `None`
- `action`
- a language task instruction through the dataset task field, or precomputed `context` and `context_mask` tensors
The default visual setup is one image feature named `observation.images.image` with shape `(3, 224, 448)`. If the dataset uses two cameras, configure `policy.input_features` so their heights match `224` and their widths sum to `448`.
## Usage
Create a new FastWAM policy with:
```bash
lerobot-train \
--dataset.repo_id=your-org/your-dataset \
--policy.type=fastwam \
--policy.action_dim=7 \
--policy.proprio_dim=8 \
--policy.action_horizon=32 \
--policy.n_action_steps=10 \
--policy.image_size='[224,448]' \
--output_dir=./outputs/fastwam_training \
--job_name=fastwam_training \
--steps=300000 \
--batch_size=8 \
--policy.device=cuda
```
Evaluate an existing LeRobot-format checkpoint on LIBERO-10 with:
```bash
lerobot-eval \
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
--policy.device=cuda \
--policy.torch_dtype=float32 \
--policy.n_action_steps=10 \
--env.type=libero \
--env.task=libero_10 \
--env.observation_height=224 \
--env.observation_width=224 \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=0 \
--env.episode_length=600
```
For `libero_goal`, `libero_spatial`, and `libero_object`, use
`--env.episode_length=300`.
For real-robot rollout, use the same checkpoint path:
```bash
lerobot-rollout \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \
--policy.path=your-org/fastwam-real-robot
```
## Configuration Notes
### Image Features
`policy.image_size` is the size of the concatenated FastWAM image tensor as `(height, width)`. Each configured image feature must have shape `(3, height, camera_width)`, and all camera widths must sum to the configured width.
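The shape rule above can be checked before training. The following is a minimal sketch, not the FastWAM implementation: `check_image_features` is a hypothetical helper, and the second camera key `observation.images.image2` is an assumed name used only for illustration.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width)


def check_image_features(features: Dict[str, Shape], image_size: Tuple[int, int]) -> int:
    """Validate camera shapes against the concatenated (height, width); return the summed width."""
    height, width = image_size
    total_width = 0
    for name, (channels, h, w) in features.items():
        if channels != 3:
            raise ValueError(f"{name}: expected 3 channels, got {channels}")
        if h != height:
            raise ValueError(f"{name}: height {h} does not match configured height {height}")
        total_width += w
    if total_width != width:
        raise ValueError(f"camera widths sum to {total_width}, expected {width}")
    return total_width


# Two 224x224 cameras (hypothetical key names) concatenated along the width axis.
two_cams = {
    "observation.images.image": (3, 224, 224),
    "observation.images.image2": (3, 224, 224),
}
print(check_image_features(two_cams, (224, 448)))
```

A single camera of shape `(3, 224, 448)` passes the same check, which matches the default visual setup described under Data Requirements.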
### Action Chunking
`policy.action_horizon` controls the number of future actions supervised during training and predicted during inference. `policy.n_action_steps` controls how many actions are consumed before the policy predicts a fresh chunk. `policy.n_action_steps` must be less than or equal to `policy.action_horizon`.
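The relationship between the two settings can be illustrated with a small queue. This is a sketch under stated assumptions rather than the LeRobot code path: `ChunkedActionQueue` and its `predict_chunk` callback are hypothetical names, and actions are plain numbers instead of tensors.

```python
from collections import deque
from typing import Callable, Deque, List


class ChunkedActionQueue:
    """Consume n_action_steps actions from each predicted chunk, then predict a fresh chunk."""

    def __init__(self, predict_chunk: Callable[[], List[float]], action_horizon: int, n_action_steps: int):
        if n_action_steps > action_horizon:
            raise ValueError("n_action_steps must be <= action_horizon")
        self.predict_chunk = predict_chunk
        self.action_horizon = action_horizon
        self.n_action_steps = n_action_steps
        self.queue: Deque[float] = deque()
        self.num_predictions = 0

    def select_action(self) -> float:
        if not self.queue:
            chunk = self.predict_chunk()
            if len(chunk) != self.action_horizon:
                raise ValueError("predicted chunk length must equal action_horizon")
            # Only the first n_action_steps actions are executed; the rest are discarded.
            self.queue.extend(chunk[: self.n_action_steps])
            self.num_predictions += 1
        return self.queue.popleft()


# Dummy predictor returning the indices 0..31 as "actions".
policy_queue = ChunkedActionQueue(lambda: list(range(32)), action_horizon=32, n_action_steps=10)
actions = [policy_queue.select_action() for _ in range(25)]
print(policy_queue.num_predictions)
```

With `action_horizon=32` and `n_action_steps=10`, 25 control steps trigger three chunk predictions in this sketch; a smaller `n_action_steps` replans more often at the cost of more forward passes.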
### Wan Components
FastWAM loads the Wan VAE, video DiT, text encoder, and tokenizer from the configured Wan model directory or Hugging Face Hub model id. LeRobot-format FastWAM checkpoints saved by `save_pretrained` also copy the local Wan component files needed by `from_pretrained`.
### LIBERO Action Toggle
FastWAM LIBERO checkpoints use `policy.toggle_action_dimensions=[-1]` by
default to match the gripper action convention used by the original FastWAM
evaluation pipeline:
```bash
--policy.toggle_action_dimensions='[-1]'
```
## Results
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
| Suite | Success rate | n_episodes |
| -------------- | -----------: | ---------: |
| libero_spatial | 97.6% | 500 |
| libero_object | 99.0% | 500 |
| libero_goal | 95.0% | 500 |
| libero_10 | 94.0% | 500 |
| **average** | **96.4%** | 2000 |
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300` (1x H20 140 GB).
## References
- [Fast-WAM paper](https://arxiv.org/abs/2603.16666)
- [Fast-WAM project page](https://yuantianyuan01.github.io/FastWAM/)
- [Fast-WAM code](https://github.com/yuantianyuan01/FastWAM)
- [Released upstream checkpoints](https://huggingface.co/yuanty/fastwam)
- [Wan2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B)
## Citation
```bibtex
@article{yuan2026fastwam,
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
journal = {arXiv preprint arXiv:2603.16666},
year = {2026},
url = {https://arxiv.org/abs/2603.16666}
}
```
+33 -78
@@ -1,19 +1,16 @@
# GR00T Policy
# GR00T N1.5 Policy
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
LeRobot integrates GR00T N1.7 through the `groot` policy type.
> [!WARNING]
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints and configs are rejected with a migration note. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
This document outlines the specifics of its integration and usage within the LeRobot framework.
## Model Overview
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
@@ -31,43 +28,33 @@ This approach allows the model to be highly adaptable through post-training for
## Installation Requirements
GR00T is intended for NVIDIA GPU-accelerated systems. Install LeRobot with the GR00T extra:
GR00T N1.5 currently requires Flash Attention for its internal workings.
We are working on making this optional; in the meantime, it requires an extra installation step and can only be used on CUDA-enabled devices.
1. Follow the Environment Setup section of our [Installation Guide](./installation). **Attention:** do not install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
```bash
pip install "lerobot[groot]"
```
For a source checkout:
```bash
pip install -e ".[groot]"
```
### Optional: Flash Attention acceleration
Flash Attention is a purely optional performance optimization. **LeRobot neither installs nor requires it**, and setting it up is up to the user as it has environment-specific build requirements (a matching PyTorch/CUDA toolchain). To enable it:
1. Install a `flash-attn` build matching your PyTorch/CUDA environment (see the [Flash Attention project](https://github.com/Dao-AILab/flash-attention)):
```bash
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
--index-url https://download.pytorch.org/whl/cu128
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
2. Install lerobot with the groot extra.
3. Install LeRobot by running:
3. Opt in by passing `--policy.use_flash_attention=true` when training/evaluating GR00T. If the kernel is missing or fails to import, the backbone transparently falls back to SDPA.
```bash
pip install "lerobot[groot]"
```
## Usage
To use GR00T N1.7:
To use GR00T in your LeRobot configuration, specify the policy type as:
```bash
--policy.type=groot
```python
policy.type=groot
```
## Training
@@ -100,53 +87,21 @@ accelerate launch \
## Performance Results
### LIBERO Benchmark Results
### Libero Benchmark Results
> [!NOTE]
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
> Follow our instructions for Libero usage: [Libero](./libero)
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
### GR00T N1.7 LIBERO Checkpoints
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82.0% | 92.0% |
| **Libero Object** | 99.0% | 92.0% |
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
| Suite | Checkpoint subdirectory |
| -------------- | ----------------------- |
| LIBERO Spatial | `libero_spatial` |
| LIBERO Object | `libero_object` |
| LIBERO Goal | `libero_goal` |
| LIBERO 10 | `libero_10` |
Preliminary LeRobot integration results:
| Suite | Status | Success rate | n_episodes |
| -------------- | ------ | -----------: | ---------: |
| LIBERO Spatial | ✓ | ~95% | XX |
| LIBERO Object | ✓ | XX% | XX |
| LIBERO Goal | ✓ | XX% | XX |
| LIBERO 10 | ✓ | XX% | XX |
| **Average** | ✓ | **XX%** | **XX** |
Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash
hf download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
lerobot-eval \
--policy.type=groot \
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
--policy.embodiment_tag=libero_sim \
--env.type=libero \
--env.task=libero_spatial \
--eval.n_episodes=50
```
Use `eval.n_episodes >= 50` per suite when reporting success rates.
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
### Evaluate in your hardware setup
@@ -176,4 +131,4 @@ lerobot-rollout\
## License
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
+1
@@ -647,5 +647,6 @@ The `--strategy.type` flag selects the execution mode:
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
- `episodic`: Episode-oriented policy recording with reset phases between episodes
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
+38
@@ -157,6 +157,44 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
| `--teleop.type` | **Required.** Teleoperator type |
### Episodic (`--strategy.type=episodic`)
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
```bash
lerobot-rollout \
--strategy.type=episodic \
--policy.path=${HF_USER}/my_policy \
--robot.type=so100_follower \
--robot.port=/dev/ttyACM0 \
--teleop.type=so100_leader \
--teleop.port=/dev/ttyACM1 \
--dataset.repo_id=${HF_USER}/my_eval_data \
--dataset.num_episodes=20 \
--dataset.episode_time_s=30 \
--dataset.reset_time_s=10 \
--dataset.single_task="Pick up the red cube"
```
Teleop is optional — if omitted the robot holds its position during the reset phase.
**Keyboard controls:**
| Key | Action |
| ----------- | -------------------------------- |
| `→` (right) | End the current episode early |
| `←` (left) | Discard episode and re-record it |
| `ESC` | Stop the recording session |
| Flag | Description |
| ----------------------------------------------- | -------------------------------------------------------------------------- |
| `--dataset.num_episodes` | Number of episodes to record |
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
---
## Inference Backends
+56
@@ -0,0 +1,56 @@
## Research Paper
Paper: https://arxiv.org/abs/2603.16666
## Repository
Code: https://github.com/yuantianyuan01/FastWAM
Project page: https://yuantianyuan01.github.io/FastWAM/
## Citation
```bibtex
@article{yuan2026fastwam,
title = {Fast-WAM: Do World Action Models Need Test-time Future Imagination?},
author = {Tianyuan Yuan and Zibin Dong and Yicheng Liu and Hang Zhao},
journal = {arXiv preprint arXiv:2603.16666},
year = {2026},
url = {https://arxiv.org/abs/2603.16666}
}
```
## Additional Resources
Base video model: https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B
Released upstream checkpoints: https://huggingface.co/yuanty/fastwam
## Results
Evaluated on LIBERO with [`ZibinDong/fastwam_libero_uncond_2cam224`](https://huggingface.co/ZibinDong/fastwam_libero_uncond_2cam224):
| Suite | Success rate | n_episodes |
| -------------- | -----------: | ---------: |
| libero_spatial | 97.6% | 500 |
| libero_object | 99.0% | 500 |
| libero_goal | 95.0% | 500 |
| libero_10 | 94.0% | 500 |
| **average** | **96.4%** | 2000 |
Reproduce: `lerobot-eval --policy.path=ZibinDong/fastwam_libero_uncond_2cam224 --policy.device=cuda --policy.torch_dtype=float32 --policy.n_action_steps=10 --env.type=libero --env.task=libero_spatial --env.observation_height=256 --env.observation_width=256 --eval.batch_size=1 --eval.n_episodes=50 --seed=0 --env.episode_length=300`.
For LIBERO-10, use `--env.task=libero_10 --env.episode_length=600`:
```bash
lerobot-eval \
--policy.path=ZibinDong/fastwam_libero_uncond_2cam224 \
--policy.device=cuda \
--policy.torch_dtype=float32 \
--policy.n_action_steps=10 \
--env.type=libero \
--env.task=libero_10 --env.observation_height=256 --env.observation_width=256 \
--eval.batch_size=1 \
--eval.n_episodes=50 \
--seed=0 --env.episode_length=600
```
+2 -108
@@ -1,13 +1,6 @@
## Research Paper
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
> Current releases support GR00T N1.7 only.
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
## Repository
@@ -31,103 +24,4 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Models:
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
## Original-vs-LeRobot parity test
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
over every embodiment tag present in the checkpoint:
1. **Model parity** — given byte-identical pre-processed inputs and the same
flow-matching seed (recorded in each artifact), both implementations must produce
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
flow-matching prediction). Output shapes must match exactly; any action-horizon
or action-dim mismatch fails the test.
2. **Preprocessor parity** — given the identical raw observations (per-camera
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
state normalization, no mocks) must produce the **same collated model inputs**
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
`embodiment_id`) as the original package's processor.
### Why two environments
The original `gr00t` package pins `transformers==4.57.3` (Python 3.10); this
integration requires `transformers>=5.x` (Qwen3-VL). Under 5.x, `PretrainedConfig`
is itself a defaulted dataclass, so the original config dataclasses fail to import
(`non-default argument follows default argument`). The two implementations therefore
**cannot be imported in the same Python process**.
So the test uses a **producer / consumer** split across two venvs:
1. **Producer**: `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
gr00t venv. For each embodiment it builds dummy inputs generically from the
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
the processor modality configs), runs the original model, and saves to one `.npz`
per tag: the raw observations (`raw::` keys), the exact collated inputs
(`in::` keys), the seed, and the raw `action_pred`.
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
`.npz`; the model-parity case replays the byte-identical collated inputs through
the LeRobot model with the recorded seed and asserts the outputs match, and the
preprocessor-parity case replays the raw observations through LeRobot's full
preprocessor pipeline and asserts the collated tensors match.
> Artifacts generated by older versions of the dump script contain no `raw::`
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
> Re-run the producer to refresh them.
### Fairness controls
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
model comparison isolates the model. LeRobot's own tokenization / image packing is
covered separately by the preprocessor-parity case, which compares its output
against those same collated tensors from identical raw observations.
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
kernel/rounding noise, not an implementation difference.)
- **Same flow-matching seed** — fixed right before sampling on both sides; the
producer records it in each artifact (`--seed`, default 42) and the consumer
replays the recorded value.
### How to run
```bash
# Resolve a local checkpoint (GR00T-N1.7-LIBERO / libero_10)
CKPT=$(python - <<'PY'
import os
from huggingface_hub import snapshot_download
print(os.path.join(snapshot_download("nvidia/GR00T-N1.7-LIBERO",
allow_patterns=["libero_10/*"]), "libero_10"))
PY
)
# 1) Produce the original-side artifacts for all embodiments (original gr00t venv, CUDA)
CUDA_VISIBLE_DEVICES=0 /path/to/Isaac-GR00T/.venv-original/bin/python \
tests/policies/groot/utils/dump_original_n1_7.py \
--ckpt "$CKPT" --out-dir tests/policies/groot/artifacts --device cuda --seed 42
# 2) Run the parity test (LeRobot venv) — one parametrized case per embodiment
CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
```
The `.npz` artifacts are local-only (gitignored, ~610 MB each) and are regenerated by
the producer; they are never committed. The tests **skip** (do not fail) on CI or
when the checkpoint / artifacts are absent.
#### Env knobs (all optional)
| Var | Default | Purpose |
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
+8 -1
@@ -208,12 +208,18 @@ groot = [
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
fastwam = [
"lerobot[transformers-dep]",
"lerobot[diffusers-dep]",
]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
@@ -278,7 +284,8 @@ all = [
"lerobot[pi]",
"lerobot[molmoact2]",
"lerobot[smolvla]",
"lerobot[groot]",
"lerobot[fastwam]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
+70
@@ -18,6 +18,7 @@ from __future__ import annotations
# Utilities
########################################################################################
import logging
import time
import traceback
from contextlib import nullcontext
from copy import copy
@@ -243,3 +244,72 @@ def sanity_check_dataset_robot_compatibility(
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
)
########################################################################################
# Teleoperator smooth handover helpers
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
# being a root module.
########################################################################################
def teleop_supports_feedback(teleop) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback (i.e. have non-empty
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
``target_pos`` is expected to be in the teleop's action/feedback key space.
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
the robot action key space directly.
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
follower robot does not receive new actions, which could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: (current[k] * (1 - t) + target_pos[k] * t) if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def follower_smooth_move_to(
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm to
the follower, the follower is brought to the teleop's current pose so the
robot meets the operator's hand rather than jumping to it on the first frame.
Both ``current`` and ``target`` must be in the robot action key space
(i.e. the output of ``robot_action_processor``).
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: (current[k] * (1 - t) + target[k] * t) if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
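Review note: the two smooth-move helpers above share the same linear-interpolation core. The sketch below isolates that core so the waypoint schedule can be checked without hardware. It is a minimal illustration, not part of the diff: `interpolation_waypoints` and the joint key names are hypothetical, and the real helpers send each waypoint to the device and sleep `1 / fps` between them instead of collecting a list.

```python
def interpolation_waypoints(
    current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> list[dict]:
    """Return the waypoints a smooth move would send, one per tick.

    Hypothetical helper for illustration only. Keys missing from ``target``
    are held at their ``current`` value, as in the helpers above.
    """
    # At least one step, so a zero or very short duration still reaches the target.
    steps = max(int(duration_s * fps), 1)
    waypoints = []
    for step in range(steps + 1):
        t = step / steps  # goes from 0.0 (current pose) to 1.0 (target pose)
        waypoints.append(
            {
                k: (current[k] * (1 - t) + target[k] * t) if k in target else current[k]
                for k in current
            }
        )
    return waypoints


# Hypothetical joint names; real key spaces depend on the robot/teleop pair.
waypoints = interpolation_waypoints(
    current={"shoulder.pos": 0.0, "gripper.pos": 10.0},
    target={"shoulder.pos": 90.0},
    duration_s=0.5,
    fps=4,
)
for waypoint in waypoints:
    print(waypoint)
```

Because the loop runs over `steps + 1` values, both endpoints are sent, so a move takes one tick longer than `duration_s * fps` ticks.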
+2
@@ -18,6 +18,7 @@ from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .eo1.configuration_eo1 import EO1Config as EO1Config
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .fastwam.configuration_fastwam import FastWAMConfig as FastWAMConfig
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config as MolmoAct2Config
@@ -42,6 +43,7 @@ __all__ = [
"ACTConfig",
"DiffusionConfig",
"EO1Config",
"FastWAMConfig",
"GaussianActorConfig",
"GrootConfig",
"MolmoAct2Config",
+33 -14
@@ -47,6 +47,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .fastwam.configuration_fastwam import FastWAMConfig
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
@@ -162,6 +163,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
elif name == "fastwam":
from .fastwam.modeling_fastwam import FastWAMPolicy
return FastWAMPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -218,6 +223,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
elif policy_type == "fastwam":
return FastWAMConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -280,22 +287,26 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
# GROOT handles normalization in groot_pack_inputs_v3 step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
return make_groot_pre_post_processors_from_pretrained(
config=policy_cfg,
pretrained_path=pretrained_path,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
preprocessor_config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
postprocessor_config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
)
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -444,6 +455,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, FastWAMConfig):
from .fastwam.processor_fastwam import make_fastwam_pre_post_processors
processors = make_fastwam_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
+1
@@ -0,0 +1 @@
../../../../docs/source/policy_fastwam_README.md
+23
@@ -0,0 +1,23 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_fastwam import FastWAMConfig
from .modeling_fastwam import FastWAMPolicy
from .processor_fastwam import make_fastwam_pre_post_processors
__all__ = [
"FastWAMConfig",
"FastWAMPolicy",
"make_fastwam_pre_post_processors",
]
@@ -0,0 +1,394 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.configs import (
FeatureType,
NormalizationMode,
PolicyFeature,
PreTrainedConfig,
)
from lerobot.optim import AdamWConfig
from lerobot.utils.constants import ACTION, OBS_STATE
WAN22_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B"
FASTWAM_BASE_MODEL_ID = "lerobot/fastwam-base"
_FASTWAM_VIDEO_BASE_COMPAT_KEYS = (
"patch_size",
"in_dim",
"hidden_dim",
"ffn_dim",
"freq_dim",
"text_dim",
"out_dim",
"num_heads",
"attn_head_dim",
"num_layers",
)
_FASTWAM_ACTION_BASE_COMPAT_KEYS = (
"hidden_dim",
"ffn_dim",
"num_heads",
"attn_head_dim",
"num_layers",
"text_dim",
"freq_dim",
)
def default_video_dit_config(action_dim: int) -> dict[str, Any]:
return {
"patch_size": [1, 2, 2],
"in_dim": 48,
"hidden_dim": 3072,
"ffn_dim": 14336,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 48,
"num_heads": 24,
"attn_head_dim": 128,
"num_layers": 30,
"eps": 1.0e-6,
"separated_timestep": True,
"use_gradient_checkpointing": False,
"video_attention_mask_mode": "first_frame_causal",
"action_conditioned": False,
"action_dim": action_dim,
"action_group_causal_mask_mode": "group_diagonal",
"fp32_attention": True,
}
def default_action_dit_config(action_dim: int) -> dict[str, Any]:
return {
"action_dim": action_dim,
"hidden_dim": 1024,
"ffn_dim": 4096,
"num_heads": 24,
"attn_head_dim": 128,
"num_layers": 30,
"text_dim": 4096,
"freq_dim": 256,
"eps": 1.0e-6,
"use_gradient_checkpointing": False,
"fp32_attention": True,
}
def _coerce_enum(enum_cls: type, value: Any) -> Any:
if isinstance(value, enum_cls):
return value
try:
return enum_cls(value)
except (TypeError, ValueError):
return getattr(enum_cls, str(value), value)
def _coerce_policy_features(features: dict[str, Any] | None) -> dict[str, PolicyFeature] | None:
if features is None:
return None
coerced = {}
for name, feature in features.items():
if isinstance(feature, PolicyFeature):
coerced[name] = feature
continue
coerced[name] = PolicyFeature(
type=_coerce_enum(FeatureType, feature["type"]),
shape=tuple(feature["shape"]),
)
return coerced
def _is_local_model_id(value: str) -> bool:
path = Path(value).expanduser()
return path.is_absolute() or value.startswith(("./", "../", "~")) or path.exists()
def _validate_wan_model_id(value: str, field_name: str) -> str:
if value == WAN22_MODEL_ID or _is_local_model_id(value):
return value
raise ValueError(f"`{field_name}` must be `{WAN22_MODEL_ID}` or an explicit local path, got `{value}`.")
def is_fastwam_base_compatible_config(config: FastWAMConfig) -> bool:
"""Return whether `fastwam-base` partial weights can initialize this config."""
default_video_config = default_video_dit_config(config.action_dim)
default_action_config = default_action_dit_config(config.action_dim)
return all(
config.video_dit_config.get(key) == default_video_config.get(key)
for key in _FASTWAM_VIDEO_BASE_COMPAT_KEYS
) and all(
config.action_dit_config.get(key) == default_action_config.get(key)
for key in _FASTWAM_ACTION_BASE_COMPAT_KEYS
)
@PreTrainedConfig.register_subclass("fastwam")
@dataclass
class FastWAMConfig(PreTrainedConfig):
"""Configuration for the FastWAM LeRobot policy.
Args:
action_dim (int): Number of scalar action channels per timestep.
proprio_dim (int | None): Number of proprioception channels used as an
extra text-context token. `None` disables proprio conditioning.
action_horizon (int): Number of actions predicted by one policy call.
num_video_frames (int): Raw video sampling window (in dataset frames). The
model actually operates on `model_video_frames` frames after subsampling
by `action_video_freq_ratio`.
action_video_freq_ratio (int): Actions are sampled at this multiple of the
video frame rate. Video frames are taken every `action_video_freq_ratio`-th
raw frame, so the model sees `(num_video_frames - 1) // ratio + 1` frames
spanning the same time window as `action_horizon` actions (ratio actions
per video frame).
image_size (tuple[int, int]): Concatenated image size as `(height, width)`.
context_len (int): Maximum text embedding token length.
video_dit_config (dict[str, Any] | None): Wan video expert config.
action_dit_config (dict[str, Any] | None): Action expert config.
use_gradient_checkpointing (bool): Enable activation checkpointing in both DiT
experts (trades compute for memory; propagated into the DiT configs).
freeze_video_expert (bool): Freeze the ~5B Wan video expert
(`model.video_expert`) so only the action expert + proprio encoder train.
Cuts the AdamW optimizer footprint substantially; the video expert keeps its
pretrained weights. (If enabled, also set `loss.lambda_video=0` to skip the
now-gradient-free video loss compute.)
"""
n_obs_steps: int = 1
action_dim: int = 7
proprio_dim: int | None = 8
action_horizon: int = 32
n_action_steps: int = 32
num_video_frames: int = 33
action_video_freq_ratio: int = 4
image_size: tuple[int, int] = (224, 448)
context_len: int = 128
model_id: str = WAN22_MODEL_ID
tokenizer_model_id: str = WAN22_MODEL_ID
base_model_id: str | None = FASTWAM_BASE_MODEL_ID
tokenizer_max_len: int = 128
load_text_encoder: bool = True
mot_checkpoint_mixed_attn: bool = False
torch_dtype: str = "bfloat16"
prompt_template: str = (
"A video recorded from a robot's point of view executing the following instruction: {task}"
)
num_inference_steps: int = 10
inference_seed: int | None = 42
rand_device: str = "cpu"
text_cfg_scale: float = 1.0
negative_prompt: str = ""
sigma_shift: float | None = None
tiled: bool = False
fp32_attention: bool = True
use_gradient_checkpointing: bool = False
freeze_video_expert: bool = False
toggle_action_dimensions: list[int] = field(default_factory=list)
video_scheduler: dict[str, float | int] = field(
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
)
action_scheduler: dict[str, float | int] = field(
default_factory=lambda: {"train_shift": 5.0, "infer_shift": 5.0, "num_train_timesteps": 1000}
)
loss: dict[str, float] = field(default_factory=lambda: {"lambda_video": 1.0, "lambda_action": 1.0})
video_dit_config: dict[str, Any] | None = None
action_dit_config: dict[str, Any] | None = None
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
input_features: dict[str, PolicyFeature] | None = None
output_features: dict[str, PolicyFeature] | None = None
optimizer_lr: float = 1.0e-4
optimizer_weight_decay: float = 1.0e-2
def __post_init__(self) -> None:
super().__post_init__()
self.image_size = tuple(self.image_size)
self.model_id = _validate_wan_model_id(self.model_id, "model_id")
self.tokenizer_model_id = _validate_wan_model_id(self.tokenizer_model_id, "tokenizer_model_id")
self.input_features = _coerce_policy_features(self.input_features)
self.output_features = _coerce_policy_features(self.output_features)
self.toggle_action_dimensions = [int(dim) for dim in self.toggle_action_dimensions]
self.video_dit_config = self.video_dit_config or default_video_dit_config(self.action_dim)
self.action_dit_config = self.action_dit_config or default_action_dit_config(self.action_dim)
self.video_dit_config["fp32_attention"] = bool(self.fp32_attention)
self.action_dit_config["fp32_attention"] = bool(self.fp32_attention)
self.video_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
self.action_dit_config["use_gradient_checkpointing"] = bool(self.use_gradient_checkpointing)
if self.input_features is None:
height, width = self.image_size
self.input_features = {
"observation.images.image": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, height, width),
)
}
if self.proprio_dim is not None:
self.input_features[OBS_STATE] = PolicyFeature(
type=FeatureType.STATE,
shape=(self.proprio_dim,),
)
if self.output_features is None:
self.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,))}
self.validate_features()
if self.pretrained_path or self.use_peft or not self.base_model_id:
return
if not is_fastwam_base_compatible_config(self):
return
self.pretrained_path = Path(self.base_model_id)
self._auto_pretrained_path = True
def _save_pretrained(self, save_directory: Path) -> None:
if not getattr(self, "_auto_pretrained_path", False):
super()._save_pretrained(save_directory)
return
pretrained_path = self.pretrained_path
self.pretrained_path = None
try:
super()._save_pretrained(save_directory)
finally:
self.pretrained_path = pretrained_path
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
def get_scheduler_preset(self) -> None:
return None
def set_dataset_feature_metadata(self, dataset_features: dict[str, Any]) -> None:
"""Rebuild visual input features from the dataset's real camera keys.
FastWAM's `__post_init__` installs a synthetic single-image default
(`observation.images.image` at full `image_size` width). For datasets
with one or more separately-named cameras (e.g. `observation.images.top`,
`observation.images.wrist`), this hook — invoked by `make_policy` once the
dataset metadata is known — replaces that default with the actual camera
keys, each declared at the policy's native per-camera resolution
(`image_size[0]` x `image_size[1] // num_cameras`). The accompanying
resize step in `make_fastwam_pre_post_processors` resizes raw frames to
match, so heterogeneous source resolutions (e.g. 480x640) are supported.
"""
image_keys = sorted(
key
for key, feature in dataset_features.items()
if key.startswith("observation.images.") and feature.get("dtype") in ("video", "image")
)
if not image_keys:
return
height, total_width = self.image_size
per_cam_width = total_width // len(image_keys)
new_inputs: dict[str, PolicyFeature] = {
key: PolicyFeature(type=FeatureType.VISUAL, shape=(3, height, per_cam_width))
for key in image_keys
}
if self.proprio_dim is not None and OBS_STATE in dataset_features:
new_inputs[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.proprio_dim,))
self.input_features = new_inputs
self.validate_features()
def validate_features(self) -> None:
if self.action_dim <= 0:
raise ValueError(f"`action_dim` must be positive, got {self.action_dim}.")
if self.action_horizon <= 0:
raise ValueError(f"`action_horizon` must be positive, got {self.action_horizon}.")
if self.n_action_steps > self.action_horizon:
raise ValueError("`n_action_steps` cannot exceed `action_horizon`.")
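if self.action_video_freq_ratio <= 0:
raise ValueError(
f"`action_video_freq_ratio` must be positive, got {self.action_video_freq_ratio}."
)
# A single-frame window has zero video transitions, which would make the
# `action_horizon` divisibility check below raise ZeroDivisionError.
if self.num_video_frames <= 1:
raise ValueError(f"`num_video_frames` must be greater than 1, got {self.num_video_frames}.")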
# Video frames are subsampled by action_video_freq_ratio; the resulting model frame
# count must satisfy T % 4 == 1 for the VAE temporal tokenization (mirrors the
# original FastWAM dataset asserts).
if (self.num_video_frames - 1) % self.action_video_freq_ratio != 0:
raise ValueError(
f"`num_video_frames - 1` ({self.num_video_frames - 1}) must be divisible by "
f"`action_video_freq_ratio` ({self.action_video_freq_ratio})."
)
if ((self.num_video_frames - 1) // self.action_video_freq_ratio) % 4 != 0:
raise ValueError(
f"Subsampled video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio}) "
"must be divisible by 4 for VAE tokenization (i.e. model_video_frames % 4 == 1)."
)
if self.action_horizon % ((self.num_video_frames - 1) // self.action_video_freq_ratio) != 0:
raise ValueError(
f"`action_horizon` ({self.action_horizon}) must be divisible by the number of "
f"video transitions ({(self.num_video_frames - 1) // self.action_video_freq_ratio})."
)
if not self.image_features:
raise ValueError("FastWAM requires at least one image feature.")
if self.action_feature is None:
raise ValueError("FastWAM requires `action` in output_features.")
action_shape = tuple(self.action_feature.shape)
if action_shape != (self.action_dim,):
raise ValueError(
f"FastWAM action feature shape must be ({self.action_dim},), got {action_shape}."
)
if self.proprio_dim is not None:
state_feature = self.robot_state_feature
if state_feature is None:
raise ValueError("FastWAM requires `observation.state` when `proprio_dim` is set.")
state_shape = tuple(state_feature.shape)
if state_shape != (self.proprio_dim,):
raise ValueError(
f"FastWAM state feature shape must be ({self.proprio_dim},), got {state_shape}."
)
height, width = self.image_size
image_width_sum = 0
for name, feature in self.image_features.items():
shape = tuple(feature.shape)
if len(shape) != 3 or shape[0] != 3:
raise ValueError(f"FastWAM image feature `{name}` must have shape (3, H, W), got {shape}.")
if shape[1] != height:
raise ValueError(f"FastWAM image feature `{name}` height must be {height}, got {shape[1]}.")
image_width_sum += shape[2]
if image_width_sum != width:
raise ValueError(f"FastWAM image feature widths must sum to {width}, got {image_width_sum}.")
@property
def model_video_frames(self) -> int:
"""Number of video frames the model actually operates on, after subsampling the
raw `num_video_frames` window by `action_video_freq_ratio` (e.g. 33 -> 9)."""
return (self.num_video_frames - 1) // self.action_video_freq_ratio + 1
@property
def observation_delta_indices(self) -> list[int]:
# Load the video frames the model is supervised on: the future window subsampled by
# action_video_freq_ratio (e.g. [0, 4, 8, ..., 32] -> 9 frames). Each video frame is
# thus `action_video_freq_ratio` actions apart, while actions load at the full rate
# (`action_delta_indices` = range(action_horizon)). Returning None would load only the
# current frame, making the video target a static repeat (degenerate supervision).
return list(range(0, self.num_video_frames, self.action_video_freq_ratio))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.action_horizon))
@property
def reward_delta_indices(self) -> None:
return None
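Review note: the frame/action timing rules in `validate_features`, `model_video_frames`, and the two `*_delta_indices` properties can be checked in isolation. The sketch below restates them as one plain function under the default values from `FastWAMConfig`. `video_frame_schedule` is a hypothetical name used only here, and the explicit zero-transition guard is an assumption added for the sketch rather than something taken from the config class.

```python
def video_frame_schedule(
    num_video_frames: int = 33,
    action_video_freq_ratio: int = 4,
    action_horizon: int = 32,
) -> dict:
    """Hypothetical helper mirroring the timing checks in the config above."""
    if action_video_freq_ratio <= 0:
        raise ValueError("action_video_freq_ratio must be positive")
    transitions, remainder = divmod(num_video_frames - 1, action_video_freq_ratio)
    if remainder != 0:
        raise ValueError("num_video_frames - 1 must be divisible by the ratio")
    # Assumption for this sketch: reject an empty window up front so the
    # modulo below never divides by zero.
    if transitions <= 0:
        raise ValueError("need at least one video transition")
    if transitions % 4 != 0:
        raise ValueError("subsampled transitions must be divisible by 4")
    if action_horizon % transitions != 0:
        raise ValueError("action_horizon must be divisible by the transitions")
    return {
        # Frames the model operates on after subsampling.
        "model_video_frames": transitions + 1,
        # Dataset offsets loaded for the video target.
        "observation_delta_indices": list(range(0, num_video_frames, action_video_freq_ratio)),
        # Actions are loaded at the full rate.
        "action_delta_indices": list(range(action_horizon)),
        "actions_per_video_transition": action_horizon // transitions,
    }


schedule = video_frame_schedule()
print(schedule["model_video_frames"])
print(schedule["observation_delta_indices"])
```

With the defaults, the 33-frame raw window is subsampled to 9 model frames, and each of the 8 video transitions spans 4 actions.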
@@ -0,0 +1,540 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
from collections import deque
from pathlib import Path
from typing import Any
import torch
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import OBS_STATE
from .configuration_fastwam import FastWAMConfig
from .modular_fastwam import ActionDiT, FastWAM, MoT
from .wan_components import (
build_wan_tokenizer,
load_pretrained_wan_text_encoder,
load_pretrained_wan_vae,
)
from .wan_video_dit import WanVideoDiT
# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first
# eval episode's action chunks through `infer_joint` so the predicted video latents are
# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path).
_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1"
# Debug viz knob: extra divisor on the predicted-frame advance per env step. It should stay 1
# now that the model emits model_video_frames, because frames_per_step =
# (model_video_frames - 1) / action_horizon already encodes action_video_freq_ratio. It was 4
# to compensate for the now-fixed bug where the model ran on the un-subsampled num_video_frames.
_DEBUG_PRED_RATE_DIV = 1
class FastWAMPolicy(PreTrainedPolicy):
"""LeRobot policy wrapper for FastWAM.
Args:
config (FastWAMConfig): FastWAM policy configuration.
dataset_stats (dict[str, dict[str, Tensor]] | None): Optional LeRobot
dataset statistics passed by the training/evaluation stack.
"""
config_class = FastWAMConfig
name = "fastwam"
def __init__(
self,
config: FastWAMConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
**kwargs: Any,
):
# `make_policy`/`from_pretrained` forward extra kwargs (e.g. `dataset_meta`); the
# dataset feature metadata is already applied to `config` by make_policy upstream,
# so we accept and ignore them, matching the other LeRobot policies.
super().__init__(config, dataset_stats)
config.validate_features()
self.config = config
self.dataset_stats = dataset_stats
self.model = self._build_core_model(config)
if config.freeze_video_expert and getattr(self.model, "video_expert", None) is not None:
# Freeze the ~5B Wan video expert; get_optim_params filters on requires_grad,
# so its params drop out of the optimizer (and DDP skips them).
self.model.video_expert.requires_grad_(False)
# The transformer blocks are re-parented onto the MoTLayers (single FSDP owner), so
# `video_expert.requires_grad_` no longer reaches them — freeze them via the layers.
mot = getattr(self.model, "mot", None)
if mot is not None and getattr(mot, "layers", None) is not None:
for layer in mot.layers:
if "video" in layer.blocks:
layer.blocks["video"].requires_grad_(False)
self.reset()
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
# counts only eval-rollout resets (one per episode), not this __init__ one.
self._debug_constructed = True
self._debug_episode_index = -1
self._debug_seen_tasks: set[str] = set()
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video: list | None = None
self._debug_pairs: list = []
@classmethod
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
"""Shape-aware load that supports cross-embodiment fine-tuning.
`safetensors.load_model(strict=False)` ignores missing/unexpected keys but
still raises on a shape mismatch for a shared key. When fine-tuning from a
checkpoint trained on a different embodiment (e.g. the LIBERO 7-DoF / 8-dim
checkpoint adapted to a 6-DoF / 6-dim arm), the action encoder/head and
proprio encoder legitimately differ in shape. With `strict=False` we drop
only those shape-mismatched tensors — leaving them at their freshly
initialized values — and load every compatible tensor. With `strict=True`
the standard exact-match loader is used.
"""
from safetensors import safe_open
model_state_dict = model.state_dict()
mismatched = []
with safe_open(model_file, framework="pt") as f:
checkpoint_keys = list(f.keys())
for key in checkpoint_keys:
if key in model_state_dict and tuple(model_state_dict[key].shape) != tuple(
f.get_slice(key).get_shape()
):
mismatched.append(key)
if not mismatched:
return super()._load_as_safetensor(model, model_file, map_location, strict)
if strict:
raise RuntimeError(
f"FastWAM: {len(mismatched)} checkpoint tensors have a shape mismatch under "
f"strict=True: {mismatched}"
)
from safetensors.torch import load_file
logging.warning(
"FastWAM cross-embodiment load: reinitializing %d shape-mismatched tensor(s), keeping "
"every compatible weight: %s",
len(mismatched),
mismatched,
)
state_dict = load_file(model_file, device="cpu")
for key in mismatched:
state_dict.pop(key, None)
model.load_state_dict(state_dict, strict=False)
if map_location and map_location != "cpu":
model.to(map_location)
return model
def get_optim_params(self) -> list[Tensor]:
# Return the trainable tensors directly (a single param group). The optimizer
# builder wraps these in a param group; returning a bare {"params": [...]} dict
# instead would make `list(...)` yield the key string "params".
params = (
list(self.model.dit.parameters()) if hasattr(self.model, "dit") else list(self.model.parameters())
)
proprio_encoder = getattr(self.model, "proprio_encoder", None)
if proprio_encoder is not None:
params.extend(list(proprio_encoder.parameters()))
return [p for p in params if p.requires_grad]
def reset(self) -> None:
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
# TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's
# true-vs-pred video if it was a captured one (pairs accumulate only while
# capturing), then reset per-episode capture state.
if getattr(self, "_debug_constructed", False):
if _FASTWAM_DECODE_DEBUG and self._debug_pairs:
self._save_debug_video()
self._debug_episode_index += 1
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video = None
self._debug_pairs = []
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
`FastWAM.build_inputs` consumes (`video`, `action`, `context`/`context_mask`,
per-frame `proprio`).
The LeRobot training loop passes raw `observation.images.*`, a single-step
`observation.state` `[B, D]`, `action`, and a language `task` string. We do
only the translation `build_inputs` can't: stack the camera frames into a
video, encode the prompt with the (frozen) text encoder (mirroring inference,
so language-conditioned datasets need no precomputed context), and give proprio
the per-frame axis `build_inputs` indexes. All shape/presence validation is
left to `build_inputs`, the single authority on the contract.
"""
sample = dict(batch)
if "video" not in sample:
sample["video"] = _stack_video_from_images(batch, self.config)
if "context" not in sample or "context_mask" not in sample:
prompt = _prompt_from_batch(batch=batch, config=self.config)
if prompt is None:
raise KeyError(
"FastWAM training requires a `task`/`prompt` to encode text context, "
"or precomputed `context`/`context_mask` in the batch."
)
sample["context"], sample["context_mask"] = self.model.encode_prompt(prompt)
if self.config.proprio_dim is not None and "proprio" not in sample:
state = sample.get(OBS_STATE)
if state is not None:
# LeRobot gives a single-step state [B, D]; build_inputs expects
# per-frame [B, T, D] and uses frame 0, so add a T=1 axis.
sample["proprio"] = state.unsqueeze(1) if state.ndim == 2 else state
return sample
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
"""Compute FastWAM training loss for a LeRobot batch.
Args:
batch (dict[str, Tensor]): Batch containing FastWAM-ready keys
(`video`, `action`, `context`, `context_mask`) or LeRobot keys
that can be adapted (`observation.images.*`, `observation.state`,
`action`, `action_is_pad`).
Returns:
tuple[Tensor, dict[str, Any]]: The scalar loss to backprop, and a dict of
logging metrics (e.g. `loss_video`, `loss_action`) — the `(loss, output_dict)`
contract the LeRobot training loop expects.
"""
sample = self._batch_to_training_sample(batch)
loss, metrics = self.model.training_loss(sample)
return loss, dict(metrics or {})
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **_: Any) -> Tensor:
"""Predict a chunk of actions from the current FastWAM observation.
Args:
batch (dict[str, Tensor]): Inference batch with `input_image` or
image observation keys, plus `context`/`context_mask` or `prompt`.
Returns:
Tensor: Action chunk with shape `[B, action_horizon, action_dim]`.
"""
self.eval()
infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
batch_size = _infer_kwargs_batch_size(infer_kwargs)
# TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task),
# run the joint video+action path so the predicted video is VAE-decoded; stash it
# so select_action can pair each predicted frame with the real obs that follows.
if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1:
out = self.model.infer_joint(
**infer_kwargs,
num_video_frames=self.config.model_video_frames,
test_action_with_infer_action=False,
)
# The decoded rollout has model_video_frames frames spanning the full
# action_horizon (action_video_freq_ratio actions per frame); the per-step
# pairing indexes into it, so keep all frames.
self._debug_last_video = out["video"]
action = _action_from_model_output(out)
elif batch_size == 1:
action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
else:
action = torch.cat(
[
_action_from_model_output(
self.model.infer_action(
**_slice_infer_kwargs(infer_kwargs, index=i, batch_size=batch_size)
)
)
for i in range(batch_size)
],
dim=0,
)
return action.to(device=batch_device(batch), dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
self.eval()
# TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide
# whether to capture: yes iff this episode's task hasn't been captured yet (so we
# get the first episode of every task).
if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started:
self._debug_episode_started = True
task = self._debug_task_name(batch)
if task not in self._debug_seen_tasks:
self._debug_seen_tasks.add(task)
self._debug_capturing = True
self._debug_episode_task = task
capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
self._action_queue.extend(actions.transpose(0, 1))
if capturing:
self._debug_step_in_chunk = 0 # a fresh chunk was just predicted
if capturing:
self._debug_capture_pair(batch)
self._debug_step_in_chunk += 1
return self._action_queue.popleft()
# ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ----
@staticmethod
def _debug_task_name(batch: dict[str, Any]) -> str:
task = batch.get("task")
if isinstance(task, (list, tuple)):
task = task[0] if task else None
return str(task) if task else "no_task"
def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None:
video = getattr(self, "_debug_last_video", None)
if not video:
return
real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1]
# Map env-step offset within the chunk to a predicted-frame index. The rollout has
# (model_video_frames - 1) transitions over action_horizon actions, so each env step
# advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio,
# e.g. 8/32 = 0.25 — one predicted frame per ~4 actions).
frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon)
idx = min(
int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)),
len(video) - 1,
)
pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx])
self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx)
self._debug_pairs.append(pair)
@staticmethod
def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None:
from PIL import ImageDraw
draw = ImageDraw.Draw(pair)
draw.text((3, 3), "true", fill=(255, 255, 0))
draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0))
@staticmethod
def _debug_tensor_to_pil(image: Tensor):
from PIL import Image
arr = ((image.detach().float().clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
@staticmethod
def _debug_hstack(left, right):
from PIL import Image
if right.height != left.height:
right = right.resize((round(right.width * left.height / right.height), left.height))
canvas = Image.new("RGB", (left.width + right.width, left.height))
canvas.paste(left, (0, 0))
canvas.paste(right, (left.width, 0))
return canvas
def _save_debug_video(self) -> None:
import re
import numpy as np
from lerobot.utils.io_utils import write_video
pairs = getattr(self, "_debug_pairs", None)
if not pairs:
return
out_dir = Path("outputs/fastwam_debug")
out_dir.mkdir(parents=True, exist_ok=True)
slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task"
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
write_video(path, frames, fps=30)
logging.info(
"FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path
)
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
"""Build the FastWAM core for training / inference.
Only the trainable parts (the MoT DiT and the proprio encoder) are
materialized empty here and then filled from the policy's
`model.safetensors` by the base `from_pretrained`. The *frozen* Wan2.2 VAE
and UMT5 text encoder are loaded with their real weights from the
`Wan-AI/Wan2.2-TI2V-5B-Diffusers` repo (cached in the HF cache, shared
across checkpoints) and are intentionally excluded from `model.safetensors`
— see `FastWAM.__init__`. The tokenizer comes from `google/umt5-xxl`.
"""
dtype = _dtype_from_name(config.torch_dtype)
device = config.device
video_expert = WanVideoDiT(**config.video_dit_config).to(device=device, dtype=dtype)
action_expert = ActionDiT(**config.action_dit_config).to(device=device, dtype=dtype)
mot = MoT(
mixtures={"video": video_expert, "action": action_expert},
mot_checkpoint_mixed_attn=config.mot_checkpoint_mixed_attn,
)
text_encoder = (
load_pretrained_wan_text_encoder(torch_dtype=dtype, device=device)
if config.load_text_encoder
else None
)
return FastWAM(
video_expert=video_expert,
action_expert=action_expert,
mot=mot,
vae=load_pretrained_wan_vae(torch_dtype=dtype, device=device),
text_encoder=text_encoder,
tokenizer=build_wan_tokenizer(tokenizer_max_len=config.tokenizer_max_len),
text_dim=int(config.video_dit_config["text_dim"]),
proprio_dim=config.proprio_dim,
device=device,
torch_dtype=dtype,
video_train_shift=float(config.video_scheduler["train_shift"]),
video_infer_shift=float(config.video_scheduler["infer_shift"]),
video_num_train_timesteps=int(config.video_scheduler["num_train_timesteps"]),
action_train_shift=float(config.action_scheduler["train_shift"]),
action_infer_shift=float(config.action_scheduler["infer_shift"]),
action_num_train_timesteps=int(config.action_scheduler["num_train_timesteps"]),
loss_lambda_video=float(config.loss["lambda_video"]),
loss_lambda_action=float(config.loss["lambda_action"]),
)
def _batch_to_infer_kwargs(batch: dict[str, Tensor], config: FastWAMConfig) -> dict[str, Any]:
return {
"prompt": _prompt_from_batch(batch=batch, config=config),
"input_image": _input_image_from_batch(batch, config),
"action_horizon": config.action_horizon,
"proprio": batch.get("proprio", batch.get(OBS_STATE)),
"context": batch.get("context"),
"context_mask": batch.get("context_mask"),
"negative_prompt": batch.get("negative_prompt", config.negative_prompt),
"text_cfg_scale": float(batch.get("text_cfg_scale", config.text_cfg_scale)),
"num_inference_steps": int(batch.get("num_inference_steps", config.num_inference_steps)),
"sigma_shift": batch.get("sigma_shift", config.sigma_shift),
"seed": batch.get("seed", config.inference_seed),
"rand_device": batch.get("rand_device", config.rand_device),
"tiled": bool(batch.get("tiled", config.tiled)),
}
def _prompt_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Any:
prompt = batch.get("prompt")
if prompt is not None:
return prompt
task = batch.get("task")
if task is None:
return None
if isinstance(task, str):
return config.prompt_template.format(task=task)
if isinstance(task, (list, tuple)):
return [config.prompt_template.format(task=str(item)) for item in task]
return config.prompt_template.format(task=str(task))
def _action_from_model_output(output: Any) -> Tensor:
action = output["action"] if isinstance(output, dict) else output
if action.ndim == 2:
action = action.unsqueeze(0)
return action
def _infer_kwargs_batch_size(infer_kwargs: dict[str, Any]) -> int:
image = infer_kwargs["input_image"]
if not isinstance(image, Tensor):
raise TypeError(f"`input_image` must be a tensor, got {type(image).__name__}.")
if image.ndim == 3:
return 1
if image.ndim == 4:
return int(image.shape[0])
raise ValueError(f"`input_image` must be [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
def _slice_infer_kwargs(infer_kwargs: dict[str, Any], *, index: int, batch_size: int) -> dict[str, Any]:
return {
key: _slice_infer_value(value, index=index, batch_size=batch_size)
for key, value in infer_kwargs.items()
}
def _slice_infer_value(value: Any, *, index: int, batch_size: int) -> Any:
if isinstance(value, Tensor) and value.ndim > 0 and value.shape[0] == batch_size:
return value[index : index + 1]
if isinstance(value, (list, tuple)) and len(value) == batch_size:
return value[index]
return value
def _dtype_from_name(name: str) -> torch.dtype:
dtype_map = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
if name not in dtype_map:
raise ValueError(f"Unsupported torch dtype `{name}`.")
return dtype_map[name]
def batch_device(batch: dict[str, Any]) -> torch.device:
for value in batch.values():
if isinstance(value, Tensor):
return value.device
return torch.device("cpu")
def _stack_video_from_images(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
# Exclude the `*_is_pad` companion tensors that delta-timestamp loading adds alongside
# each camera (shape [B, T]); they share the `observation.images.` prefix but are not frames.
image_keys = sorted(k for k in batch if k.startswith("observation.images.") and not k.endswith("_is_pad"))
if not image_keys:
raise KeyError("FastWAM batch must contain `video` or `observation.images.*` keys.")
images = [batch[key] for key in image_keys]
# Cameras concatenate along width (last dim) in both the single-frame and temporal case.
image = torch.cat(images, dim=-1) if len(images) > 1 else images[0]
if image.ndim == 4:
# [B, C, H, W]: a single frame (e.g. the live eval observation) -> repeat across time.
image = image.unsqueeze(2).repeat(1, 1, config.model_video_frames, 1, 1)
elif image.ndim == 5:
# [B, T, C, H, W]: temporal stack from delta-timestamp loading -> [B, C, T, H, W].
image = image.permute(0, 2, 1, 3, 4)
else:
raise ValueError(f"Expected image batch [B,C,H,W] or temporal [B,T,C,H,W], got {tuple(image.shape)}.")
return image
def _input_image_from_batch(batch: dict[str, Tensor], config: FastWAMConfig) -> Tensor:
if "input_image" in batch:
return _prepare_infer_image(batch["input_image"], config)
video = batch.get("video")
if video is None:
video = _stack_video_from_images(batch, config)
if video.ndim == 5:
return _prepare_infer_image(video[:, :, 0], config)
if video.ndim == 4:
return _prepare_infer_image(video, config)
raise ValueError(f"Cannot build input image from tensor with shape {tuple(video.shape)}.")
def _prepare_infer_image(image: Tensor, config: FastWAMConfig) -> Tensor:
if image.ndim == 3:
image = image.unsqueeze(0)
if image.ndim != 4:
raise ValueError(f"Expected image tensor [B,C,H,W] or [C,H,W], got {tuple(image.shape)}.")
target_h, target_w = config.image_size
if tuple(image.shape[-2:]) != (target_h, target_w):
raise ValueError(
"FastWAM policy expects preprocessed image tensors with shape "
f"[B,C,{target_h},{target_w}], got {tuple(image.shape)}. "
"Run the FastWAM preprocessor before calling the policy."
)
return image
File diff suppressed because it is too large
@@ -0,0 +1,183 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import (
ActionProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
ImageCropResizeProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from .configuration_fastwam import FastWAMConfig
@dataclass
@ProcessorStepRegistry.register(name="fastwam_image_crop_resize_processor")
class FastWAMImageCropResizeProcessorStep(ImageCropResizeProcessorStep):
"""`ImageCropResizeProcessorStep` that tolerates a leading temporal/batch stack.
FastWAM loads a per-camera video stack, so image observations arrive as
``[B, T, C, H, W]``. torchvision's crop/resize only accept ``[..., C, H, W]`` with a
single leading batch dim (resize raises on 5-D input), so we flatten any leading
dims into the batch, apply the base 4-D crop/resize, then restore the leading shape.
Crop/resize params and feature-shape bookkeeping are inherited unchanged.
"""
def observation(self, observation: dict) -> dict:
# Delta-timestamp video loading adds `<image_key>_is_pad` boolean masks ([B, T]) that share
# the `observation.images.` prefix but are padding flags, not frames. The base crop/resize
# matches on the `"image"` substring, so set these aside and restore them untouched rather
# than letting it try to resize a mask.
pad_keys = {key: value for key, value in observation.items() if "_is_pad" in key}
leads: dict[str, tuple] = {}
flat_input = {key: value for key, value in observation.items() if key not in pad_keys}
for key, img in list(flat_input.items()):
if "image" in key and torch.is_tensor(img) and img.ndim > 4:
leads[key] = tuple(img.shape[:-3])
flat_input[key] = img.reshape(-1, *img.shape[-3:])
processed = super().observation(flat_input)
out = dict(processed)
for key, lead in leads.items():
im = processed[key]
out[key] = im.reshape(*lead, *im.shape[-3:])
out.update(pad_keys)
return out
@dataclass
@ProcessorStepRegistry.register(name="fastwam_action_toggle_processor")
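class FastWAMActionToggleProcessorStep(ActionProcessorStep):
"""Apply FastWAM LIBERO toggle semantics to configured action dimensions.
Each configured dimension `x` is mapped to `sign(1 - 2x)`: values above 0.5 become -1,
values below 0.5 become +1, and exactly 0.5 becomes 0. Negative indices count from the
last action dimension. All other dimensions are returned unchanged.
"""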
toggle_dimensions: list[int]
def action(self, action: PolicyAction) -> PolicyAction:
if not self.toggle_dimensions:
return action
processed_action = action.clone()
action_dim = int(processed_action.shape[-1])
for dim in self.toggle_dimensions:
resolved_dim = dim if dim >= 0 else action_dim + dim
if resolved_dim < 0 or resolved_dim >= action_dim:
raise ValueError(
f"FastWAM action toggle dimension {dim} is out of bounds for action dim {action_dim}."
)
value = processed_action[..., resolved_dim]
value = value * 2.0 - 1.0
processed_action[..., resolved_dim] = torch.sign(-value)
return processed_action
def get_config(self) -> dict[str, Any]:
return {"toggle_dimensions": self.toggle_dimensions}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def make_fastwam_pre_post_processors(
config: FastWAMConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""Create LeRobot pre- and post-processing pipelines for FastWAM.
Args:
config (FastWAMConfig): Policy configuration controlling device and
normalization feature metadata.
dataset_stats (dict[str, dict[str, torch.Tensor]] | None): Optional
LeRobot dataset statistics used by normalization processors.
Returns:
tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: Input and
output processor pipelines discoverable by LeRobot.
"""
# force visual stats to be mean 0.5 and std 0.5 to map [0, 1] data to [-1, 1]
normalization_stats: dict[str, dict[str, Any]] = dict(dataset_stats or {})
for key, feature in config.input_features.items():
if feature.type != FeatureType.VISUAL:
continue
channels = int(feature.shape[0])
normalization_stats[key] = {
"mean": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
"std": torch.full((channels, 1, 1), 0.5, dtype=torch.float32),
}
# resize visual inputs to match model expected input size, if necessary
visual_shapes = [
feature.shape for feature in config.input_features.values() if feature.type == FeatureType.VISUAL
]
resize_steps = []
if visual_shapes:
target_hw = (int(visual_shapes[0][1]), int(visual_shapes[0][2]))
# FastWAM-aware resize: tolerates the leading temporal dim of the video stack.
resize_steps.append(FastWAMImageCropResizeProcessorStep(resize_size=target_hw))
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
*resize_steps,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=normalization_stats,
device=config.device,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=normalization_stats,
),
]
if config.toggle_action_dimensions:
output_steps.append(
FastWAMActionToggleProcessorStep(toggle_dimensions=config.toggle_action_dimensions)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -0,0 +1,25 @@
# Wan2.2 Upstream Subset
This directory contains the trimmed subset of the official Wan2.2 source tree used by FastWAM.
- Upstream repository: https://github.com/Wan-Video/Wan2.2
- Upstream commit: `42bf4cfaa384bc21833865abc2f9e6c0e67233dc`
- License: Apache-2.0, matching the license in `LICENSE.txt` from the upstream repository
Copied files:
- `wan/modules/attention.py`
- `wan/modules/model.py`
- `wan/modules/__init__.py`
- `wan/utils/fm_solvers.py`
- `wan/utils/__init__.py`
This subset now only backs FastWAM's **custom MoT video DiT**. The Wan2.2 VAE,
UMT5 text encoder, and tokenizer are no longer vendored — they come from
`diffusers.AutoencoderKLWan`, `transformers.UMT5EncoderModel`, and
`transformers.AutoTokenizer` (see `../wan_adapters.py` and `../wan_components.py`).
Current FastWAM adapters that directly reuse this vendored subset:
- `../wan_video_dit.py` builds on `wan.modules.model` (`sinusoidal_embedding_1d`, `rope_params`, `rope_apply`, …) and `wan.modules.attention.flash_attention`.
- `../modular_fastwam.py` reuses `wan.utils.fm_solvers.get_sampling_sigmas` for Wan-compatible inference timesteps.
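
As a rough illustration of one of the reused helpers, the sketch below re-implements the layout of `sinusoidal_embedding_1d` from `wan/modules/model.py` in plain Python. It is a dependency-free approximation for reading purposes only, not the vendored code: the real function operates on `torch` tensors in float64, and the list-based signature here is an assumption made for the example.

```python
import math


def sinusoidal_embedding_1d(dim: int, positions: list[float]) -> list[list[float]]:
    """Plain-Python sketch: one row per position, laid out as [cos | sin]."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Frequencies decay geometrically from 1 towards 1/10000.
    freqs = [10000.0 ** (-i / half) for i in range(half)]
    rows = []
    for pos in positions:
        angles = [pos * f for f in freqs]
        # First half of the row holds cosines, second half holds sines.
        rows.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return rows


emb = sinusoidal_embedding_1d(dim=8, positions=[0.0, 1.0, 500.0])
print(len(emb), len(emb[0]))  # 3 8
print(emb[0])  # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

Position 0 always yields ones in the cosine half and zeros in the sine half, which makes it a convenient sanity check when comparing against the tensor version.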
@@ -0,0 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .attention import flash_attention
from .model import WanModel
__all__ = [
"WanModel",
"flash_attention",
]
@@ -0,0 +1,183 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
__all__ = [
"flash_attention",
"attention",
]
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.0,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
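window_size: (left, right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Cast target used when q/k/v are not already float16/bfloat16.
q_scale: float or None. If set, q is multiplied by it before attention.
version: int or None. Preferred flash-attention version; None or 3 selects FA3 when available, otherwise FA2 is used.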
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == "cuda" and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens, strict=False)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens, strict=False)]))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens, strict=False)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.", stacklevel=2)
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic,
)[0].unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
.cumsum(0, dtype=torch.int32)
.to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
).unflatten(0, (b, lq))
# output
return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.0,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
"Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.",
stacklevel=2,
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
)
out = out.transpose(1, 2).contiguous()
return out
@@ -0,0 +1,519 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention
__all__ = ["WanModel"]
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@torch.amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return super().forward(x.float()).type_as(x)
class WanSelfAttention(nn.Module):
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, seq_lens, grid_sizes, freqs):
r"""
Args:
x(Tensor): Shape [B, L, C]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size,
)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanCrossAttention(WanSelfAttention):
def forward(self, x, context, context_lens):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanAttentionBlock(nn.Module):
def __init__(
self, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, eps=1e-6
):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)
)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, L1, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with torch.amp.autocast("cuda", dtype=torch.float32):
e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), seq_lens, grid_sizes, freqs
)
with torch.amp.autocast("cuda", dtype=torch.float32):
x = x + y * e[2].squeeze(2)
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
with torch.amp.autocast("cuda", dtype=torch.float32):
x = x + y * e[5].squeeze(2)
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, L1, C]
"""
assert e.dtype == torch.float32
with torch.amp.autocast("cuda", dtype=torch.float32):
e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2))
return x
class WanModel(ModelMixin, ConfigMixin):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
_no_split_modules = ["WanAttentionBlock"]
@register_to_config
def __init__(
self,
model_type="t2v",
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
):
r"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant, one of 't2v' (text-to-video), 'i2v' (image-to-video), 'ti2v', or 's2v'
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
cross_attn_norm (`bool`, *optional*, defaults to True):
Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList(
[
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
for _ in range(num_layers)
]
)
# head
self.head = Head(dim, out_dim, patch_size, eps)
# buffers (kept as a plain attribute: with register_buffer, .to() would change the dtype)
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
)
# initialize weights
self.init_weights()
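# Illustration only (not part of the change): how the per-head dimension `d` is divided
# between the temporal, height and width axes by the three `rope_params` calls above.
# The helper name `split_rope_dims` and the example value d=128 are assumptions made for
# this sketch; the real head dimension depends on the configured `dim` and `num_heads`.

```python
def split_rope_dims(d: int) -> tuple[int, int, int]:
    """Return (temporal, height, width) channel counts, mirroring the expressions above."""
    if d % 2 != 0:
        raise ValueError(f"head dimension must be even, got {d}")
    spatial = 2 * (d // 6)      # channels given to each of the height and width axes
    temporal = d - 2 * spatial  # same as d - 4 * (d // 6): the remainder goes to time
    return temporal, spatial, spatial


parts = split_rope_dims(128)
print(parts, sum(parts))  # the three parts always add up to d again
```

# The three parts sum to `d`, so concatenating the per-axis tables along dim=1 covers the
# whole head dimension.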
def forward(
self,
x,
t,
context,
seq_len,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == "i2v":
assert y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y, strict=False)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
# time embeddings
if t.dim() == 1:
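# unsqueeze first so a [B] tensor broadcasts to [B, seq_len] for any batch size, not only B == 1
t = t.unsqueeze(1).expand(t.size(0), seq_len)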
with torch.amp.autocast("cuda", dtype=torch.float32):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_len)).float()
)
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
)
# arguments
kwargs = {
"e": e0,
"seq_lens": seq_lens,
"grid_sizes": grid_sizes,
"freqs": self.freqs,
"context": context,
"context_lens": context_lens,
}
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist(), strict=False):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size, strict=False)])
out.append(u)
return out
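# Illustration only: the shape bookkeeping performed by `unpatchify` for a single sample.
# `unpatchify_shape` is a hypothetical helper, and the grid, patch size and channel count
# below are example values chosen for this sketch, not values taken from a real config.

```python
import math


def unpatchify_shape(grid, patch_size, out_dim):
    """Return (tokens consumed, features per token, output shape) for one sample."""
    tokens = math.prod(grid)                          # u[: math.prod(v)]
    features = out_dim * math.prod(patch_size)        # C_out * prod(patch_size)
    shape = (out_dim, *[g * p for g, p in zip(grid, patch_size)])
    return tokens, features, shape


# A 2 x 3 x 4 patch grid with (1, 2, 2) patches and 16 output channels.
print(unpatchify_shape((2, 3, 4), (1, 2, 2), 16))
```

# Padding tokens beyond `math.prod(grid)` are dropped before the reshape, which is why
# the per-sample grid sizes have to be carried alongside the padded batch.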
def init_weights(self):
r"""
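Initialize model parameters: Xavier-uniform for linear layers and the patch embedding,
normal (std=0.02) for the text and time embeddings, and zeros for the output head weight.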
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
@@ -0,0 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .fm_solvers import get_sampling_sigmas
__all__ = [
"get_sampling_sigmas",
]
@@ -0,0 +1,9 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
def get_sampling_sigmas(sampling_steps, shift):
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
sigma = shift * sigma / (1 + (shift - 1) * sigma)
return sigma
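# Illustration only: the same shifted sigma schedule written with the standard library,
# to make the effect of `shift` visible. `shifted_sigmas` is a hypothetical name used
# for this sketch; it is meant to match `get_sampling_sigmas` up to floating-point error.

```python
def shifted_sigmas(sampling_steps: int, shift: float) -> list[float]:
    """Pure-Python counterpart of the NumPy schedule above."""
    # Uniform grid from 1 down to (but excluding) 0, like np.linspace(1, 0, n + 1)[:n].
    base = [1.0 - i / sampling_steps for i in range(sampling_steps)]
    # shift > 1 pushes the intermediate sigmas towards 1 (more steps at high noise).
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in base]


print(shifted_sigmas(4, 1.0))  # shift == 1 leaves the uniform grid unchanged
print(shifted_sigmas(4, 5.0))
```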
@@ -0,0 +1,111 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
from diffusers import AutoencoderKLWan
class WanVideoVAE38(torch.nn.Module):
"""FastWAM VAE contract over `diffusers.AutoencoderKLWan` (Wan2.2-TI2V-5B).
16x spatial / 4x temporal compression, 48 latent channels. diffusers'
`AutoencoderKLWan` returns *raw* latents (it does not apply `latents_mean`/
`latents_std`), so `encode`/`decode` here apply the same standardization the
Wan reference uses — `(latents - mean) / std` — done in fp32 for stability.
`encode` uses the deterministic posterior mode, matching the original VAE
which returned the latent mean `mu`.
"""
upsampling_factor = 16
temporal_downsample_factor = 4
z_dim = 48
def __init__(
self,
dtype: torch.dtype = torch.float32,
device: str | torch.device = "cuda",
*,
pretrained: AutoencoderKLWan,
) -> None:
super().__init__()
# The Wan2.2 VAE is a fixed pretrained model — it is never trained from scratch,
# so a real `AutoencoderKLWan` (with weights) must always be supplied (loaded from
# the diffusers repo by `load_pretrained_wan_vae`). No random/offline build path.
self.vae = pretrained.to(device=device, dtype=dtype)
# Read the standardization stats from the VAE's own config (diffusers populates
# these from vae/config.json) — single source of truth, no local copy. diffusers'
# encode/decode return *raw* latents, so we apply (latent - mean) / std ourselves.
# Non-persistent: kept out of state_dict.
self.register_buffer(
"latents_mean",
torch.tensor(self.vae.config.latents_mean).view(1, self.z_dim, 1, 1, 1),
persistent=False,
)
self.register_buffer(
"latents_std",
torch.tensor(self.vae.config.latents_std).view(1, self.z_dim, 1, 1, 1),
persistent=False,
)
def _device_dtype(self) -> tuple[torch.device, torch.dtype]:
param = next(self.vae.parameters())
return param.device, param.dtype
def encode(
self,
videos: list[torch.Tensor] | torch.Tensor,
device: str | torch.device | None = None,
tiled: bool = False,
tile_size: tuple[int, int] = (34, 34),
tile_stride: tuple[int, int] = (18, 16),
) -> torch.Tensor:
del device, tile_size, tile_stride
if tiled:
raise NotImplementedError("Tiled Wan2.2 VAE encoding is not supported by the FastWAM adapter.")
if isinstance(videos, (list, tuple)):
videos = torch.stack(list(videos))
dev, dtype = self._device_dtype()
mu = self.vae.encode(videos.to(device=dev, dtype=dtype)).latent_dist.mode().float()
mean = self.latents_mean.float().to(mu.device)
std = self.latents_std.float().to(mu.device)
return (mu - mean) / std
def decode(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
device: str | torch.device | None = None,
tiled: bool = False,
tile_size: tuple[int, int] = (34, 34),
tile_stride: tuple[int, int] = (18, 16),
) -> torch.Tensor:
del device, tile_size, tile_stride
if tiled:
raise NotImplementedError("Tiled Wan2.2 VAE decoding is not supported by the FastWAM adapter.")
if isinstance(hidden_states, (list, tuple)):
hidden_states = torch.stack(list(hidden_states))
dev, dtype = self._device_dtype()
z = hidden_states.float()
z = z * self.latents_std.float().to(z.device) + self.latents_mean.float().to(z.device)
out = self.vae.decode(z.to(device=dev, dtype=dtype)).sample
return out.float().clamp_(-1.0, 1.0)
__all__ = ["WanVideoVAE38"]
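# Illustration only: the per-channel standardization applied in `encode` and undone in
# `decode`, shown on plain floats. The mean/std numbers are made up for this sketch and
# are not the Wan2.2 statistics; `standardize` and `destandardize` are hypothetical names.

```python
def standardize(latent, mean, std):
    """(latent - mean) / std, one entry per latent channel."""
    return [(z - m) / s for z, m, s in zip(latent, mean, std)]


def destandardize(latent, mean, std):
    """Inverse mapping used before decoding: latent * std + mean."""
    return [z * s + m for z, m, s in zip(latent, mean, std)]


raw = [1.0, -2.0]      # hypothetical raw VAE latents for two channels
mean = [0.5, -1.0]     # hypothetical per-channel mean
std = [2.0, 4.0]       # hypothetical per-channel standard deviation
normalized = standardize(raw, mean, std)
print(normalized, destandardize(normalized, mean, std))
```

# Because the two mappings are exact inverses, a latent produced by `encode` can be
# passed straight to `decode` without any extra scaling.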
@@ -0,0 +1,172 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import AutoencoderKLWan
from .wan_adapters import WanVideoVAE38
from .wan_video_dit import WanVideoDiT
logger = logging.getLogger(__name__)
# The custom MoT video DiT still ships in the original (non-diffusers) Wan2.2
# repo as sharded `diffusion_pytorch_model*.safetensors`; the VAE and UMT5 text
# encoder come from the diffusers conversion. Tokenizer is the stock UMT5 one.
WAN_DIT_PATTERN = "diffusion_pytorch_model*.safetensors"
WAN_T5_TOKENIZER = "google/umt5-xxl"
WAN22_DIFFUSERS_MODEL_ID = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
class WanTextEncoder(torch.nn.Module):
"""FastWAM text-encoder contract over `transformers.UMT5EncoderModel`.
Exposes `.dim` (hidden size) and `forward(ids, mask) -> [B, L, dim]`, matching
the call in `FastWAM.encode_prompt`.
"""
def __init__(
self,
dtype: torch.dtype = torch.bfloat16,
device: str | torch.device = "cuda",
*,
pretrained: torch.nn.Module,
) -> None:
super().__init__()
# UMT5-XXL is a fixed pretrained encoder — never trained from scratch, so a real
# `UMT5EncoderModel` (with weights) must always be supplied (loaded from the
# diffusers repo by `load_pretrained_wan_text_encoder`). No random/offline build.
self.model = pretrained.to(device=device, dtype=dtype)
self.dim = int(self.model.config.d_model)
def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
return self.model(input_ids=ids, attention_mask=mask.long()).last_hidden_state
class WanTokenizer:
"""UMT5 tokenizer wrapper that pads/truncates to `seq_len` and returns `input_ids`,
or `(input_ids, attention_mask)` when `return_mask=True`, as the FastWAM call site expects."""
def __init__(self, name: str = WAN_T5_TOKENIZER, seq_len: int = 512) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(name)
self.seq_len = int(seq_len)
def __call__(
self,
sequence: str | Sequence[str],
return_mask: bool = False,
add_special_tokens: bool = True,
**_: Any,
):
if isinstance(sequence, str):
sequence = [sequence]
out = self.tokenizer(
list(sequence),
padding="max_length",
truncation=True,
max_length=self.seq_len,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)
if return_mask:
return out.input_ids, out.attention_mask
return out.input_ids
def build_wan_tokenizer(*, tokenizer_max_len: int) -> WanTokenizer:
return WanTokenizer(name=WAN_T5_TOKENIZER, seq_len=int(tokenizer_max_len))
def load_pretrained_wan_vae(*, torch_dtype: torch.dtype, device: str) -> WanVideoVAE38:
"""Load real Wan2.2 VAE weights from the diffusers repo (offline base creation)."""
vae = AutoencoderKLWan.from_pretrained(WAN22_DIFFUSERS_MODEL_ID, subfolder="vae", torch_dtype=torch_dtype)
return WanVideoVAE38(dtype=torch_dtype, device=device, pretrained=vae)
def load_pretrained_wan_text_encoder(*, torch_dtype: torch.dtype, device: str) -> WanTextEncoder:
"""Load real UMT5-XXL encoder weights from the diffusers repo (offline base creation)."""
encoder = UMT5EncoderModel.from_pretrained(
WAN22_DIFFUSERS_MODEL_ID, subfolder="text_encoder", torch_dtype=torch_dtype
)
return WanTextEncoder(dtype=torch_dtype, device=device, pretrained=encoder)
def resolve_wan_dit_paths(
model_id_or_path: str | Path,
*,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
) -> list[Path]:
"""Resolve the custom MoT DiT shards from the original Wan2.2 repo or a local dir."""
path = Path(model_id_or_path).expanduser()
if path.is_dir():
return sorted(path.glob(WAN_DIT_PATTERN))
snapshot_path = snapshot_download(
repo_id=str(model_id_or_path),
revision=revision,
cache_dir=cache_dir,
local_files_only=local_files_only,
allow_patterns=[WAN_DIT_PATTERN],
)
return sorted(Path(snapshot_path).glob(WAN_DIT_PATTERN))
def load_wan_video_dit(
paths: list[str | Path],
*,
dit_config: dict[str, Any],
torch_dtype: torch.dtype,
device: str,
) -> WanVideoDiT:
model = WanVideoDiT(**dit_config)
state_dict = _read_wan_dit_safetensors(paths)
model.load_state_dict(state_dict, strict=False)
return model.to(device=device, dtype=torch_dtype)
def _read_wan_dit_safetensors(paths: list[str | Path]) -> dict[str, torch.Tensor]:
state_dict = {}
for path in paths:
state_dict.update(load_file(str(path), device="cpu"))
return state_dict
__all__ = [
"WAN22_DIFFUSERS_MODEL_ID",
"WAN_DIT_PATTERN",
"WAN_T5_TOKENIZER",
"WanTextEncoder",
"WanTokenizer",
"build_wan_tokenizer",
"load_pretrained_wan_text_encoder",
"load_pretrained_wan_vae",
"load_wan_video_dit",
"resolve_wan_dit_paths",
]
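# Illustration only: `load_wan_video_dit` loads with `strict=False`, so mismatched keys
# do not raise. `load_state_dict` reports them through the `missing_keys` and
# `unexpected_keys` fields of its return value; the sketch below reproduces that
# bookkeeping on plain key names. `diff_state_dict_keys` and the key names are
# hypothetical and exist only for this example.

```python
def diff_state_dict_keys(expected, loaded):
    """Return (missing, unexpected) key lists, both sorted for stable output."""
    expected, loaded = set(expected), set(loaded)
    missing = sorted(expected - loaded)      # model parameters the shards never provided
    unexpected = sorted(loaded - expected)   # shard entries the model has no slot for
    return missing, unexpected


model_keys = ["blocks.0.ffn.0.weight", "blocks.0.ffn.0.bias", "action_embedding.weight"]
shard_keys = ["blocks.0.ffn.0.weight", "blocks.0.ffn.0.bias", "img_emb.proj.weight"]
print(diff_state_dict_keys(model_keys, shard_keys))
```

# An empty shard list would make every model key "missing", so logging both lists after
# loading is a cheap way to notice a wrong path or pattern.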
@@ -0,0 +1,813 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torch.utils.checkpoint
from einops import rearrange
from .wan.modules.model import (
WanAttentionBlock,
WanLayerNorm,
WanModel,
WanRMSNorm,
rope_apply,
rope_params,
sinusoidal_embedding_1d,
)
from .wan.utils.fm_solvers import get_sampling_sigmas
logger = logging.getLogger(__name__)
def create_custom_forward(module):
def custom_forward(*inputs, **kwargs):
return module(*inputs, **kwargs)
return custom_forward
def gradient_checkpoint_forward(
model,
use_gradient_checkpointing,
*args,
**kwargs,
):
if use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
create_custom_forward(model),
*args,
**kwargs,
use_reentrant=False,
)
else:
model_output = model(*args, **kwargs)
return model_output
def fastwam_masked_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
num_heads: int,
ctx_mask: torch.Tensor | None = None,
fp32_attention: bool = True,
) -> torch.Tensor:
"""FastWAM masked attention wrapper for MoT masks and CPU test coverage.
The official Wan attention implementation is still used as the source of
the projection/norm modules. This wrapper only replaces the final attention
kernel because FastWAM needs explicit boolean masks for video/action MoT
routing, while the upstream FlashAttention path accepts sequence lengths
but not arbitrary [query, key] masks.
"""
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
if fp32_attention:
q = q.float()
k = k.float()
v = v.float()
else:
q = q.to(dtype=v.dtype)
k = k.to(dtype=v.dtype)
x = functional.scaled_dot_product_attention(q, k, v, attn_mask=ctx_mask)
return rearrange(x, "b n s d -> b s (n d)", n=num_heads)
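# Illustration only: what a boolean mask means for the attention weights. In
# `scaled_dot_product_attention`, a boolean `attn_mask` marks with True the keys a query
# may attend to. The sketch shows the resulting softmax for one query on plain floats;
# `masked_softmax` is a hypothetical helper and not how the fused kernel is implemented.

```python
import math


def masked_softmax(scores, mask):
    """Softmax over the positions where `mask` is True; masked positions get weight 0."""
    kept = [s for s, m in zip(scores, mask) if m]
    if not kept:
        raise ValueError("at least one key must be visible to the query")
    peak = max(kept)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


# The third key has the largest score but is masked out, so it receives no weight.
print(masked_softmax([0.0, 0.0, 5.0], [True, True, False]))
```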
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
def _get_wan_sampling_sigmas(num_inference_steps: int, shift: float) -> list[float]:
return get_sampling_sigmas(num_inference_steps, shift).tolist()
class WanContinuousFlowMatchScheduler:
"""Continuous-time Flow-Matching scheduler with shift-based Wan sampling."""
def __init__(self, num_train_timesteps: int = 1000, shift: float = 5.0, eps: float = 1e-10):
if num_train_timesteps <= 0:
raise ValueError(f"`num_train_timesteps` must be positive, got {num_train_timesteps}")
if shift <= 0:
raise ValueError(f"`shift` must be positive, got {shift}")
self.num_train_timesteps = int(num_train_timesteps)
self.shift = float(shift)
self.eps = float(eps)
self._y_min, self._weight_norm_const = self._precompute_training_weight_stats()
@staticmethod
def _phi(u: torch.Tensor, shift: float) -> torch.Tensor:
return shift * u / (1.0 + (shift - 1.0) * u)
def _precompute_training_weight_stats(self) -> tuple[float, float]:
steps = self.num_train_timesteps
u_grid = torch.linspace(1.0, 0.0, steps + 1, dtype=torch.float64)[:-1]
t_grid = self._phi(u_grid, self.shift) * float(steps)
y_grid = torch.exp(-2.0 * ((t_grid - (steps / 2.0)) / steps) ** 2)
y_min = float(y_grid.min().item())
y_shifted_grid = y_grid - y_min
norm_const = float(y_shifted_grid.mean().item())
return y_min, norm_const
def sample_training_t(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
if batch_size <= 0:
raise ValueError(f"`batch_size` must be positive, got {batch_size}")
u = torch.rand((batch_size,), device=device, dtype=torch.float32)
sigma = self._phi(u, self.shift)
timestep = sigma * float(self.num_train_timesteps)
return timestep.to(dtype=dtype)
def training_weight(self, timestep: torch.Tensor) -> torch.Tensor:
t = timestep.to(dtype=torch.float32)
steps = float(self.num_train_timesteps)
y = torch.exp(-2.0 * ((t - (steps / 2.0)) / steps) ** 2)
y_shifted = y - self._y_min
weight = y_shifted / (self._weight_norm_const + self.eps)
if weight.numel() == 1:
return weight.reshape(())
return weight
def add_noise(
self, original_samples: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor
) -> torch.Tensor:
sigma = (timestep / float(self.num_train_timesteps)).to(
original_samples.device, dtype=original_samples.dtype
)
if sigma.ndim == 0:
return (1 - sigma) * original_samples + sigma * noise
sigma = sigma.view(-1, *([1] * (original_samples.ndim - 1)))
return (1 - sigma) * original_samples + sigma * noise
@staticmethod
def training_target(sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
del timestep
return noise - sample
def build_inference_schedule(
self,
num_inference_steps: int,
device: torch.device,
dtype: torch.dtype,
shift_override: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be positive, got {num_inference_steps}")
shift = self.shift if shift_override is None else float(shift_override)
if shift <= 0:
raise ValueError(f"`shift` must be positive, got {shift}")
sigma_steps = torch.as_tensor(
_get_wan_sampling_sigmas(num_inference_steps, shift),
device=device,
dtype=torch.float32,
)
timesteps = sigma_steps * float(self.num_train_timesteps)
sigma_next = torch.cat([sigma_steps[1:], sigma_steps.new_zeros(1)])
deltas = sigma_next - sigma_steps
return timesteps.to(dtype=dtype), deltas.to(dtype=dtype)
@staticmethod
def step(model_output: torch.Tensor, delta: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
delta = delta.to(sample.device, dtype=sample.dtype)
if delta.ndim == 0:
return sample + model_output * delta
delta = delta.view(-1, *([1] * (sample.ndim - 1)))
return sample + model_output * delta
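# Illustration only: how `add_noise`, `training_target` and `step` fit together. With
# sigma = timestep / num_train_timesteps the noisy sample is (1 - sigma) * x0 + sigma * noise,
# the target velocity is noise - x0, and an Euler step adds velocity * (sigma_next - sigma).
# If the model predicted the exact velocity, walking the schedule down to sigma = 0 would
# return x0. The scalar sketch below checks that; all function names here are hypothetical.

```python
def sigma_schedule(num_steps: int, shift: float) -> list[float]:
    """Shifted sigmas from 1 down to (excluding) 0, as in the inference schedule."""
    base = [1.0 - i / num_steps for i in range(num_steps)]
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in base]


def denoise_with_exact_velocity(x0: float, noise: float, num_steps: int = 4, shift: float = 5.0) -> float:
    schedule = sigma_schedule(num_steps, shift)
    # delta_i = sigma_{i+1} - sigma_i, with a final 0 appended (negative step sizes).
    deltas = [nxt - cur for cur, nxt in zip(schedule, schedule[1:] + [0.0])]
    sample = (1.0 - schedule[0]) * x0 + schedule[0] * noise  # add_noise at the first sigma
    velocity = noise - x0                                     # training_target
    for delta in deltas:
        sample = sample + velocity * delta                    # step
    return sample


print(denoise_with_exact_velocity(x0=0.3, noise=-1.2))
```

# The deltas telescope to -sigma_0, so the result depends only on the velocity estimate,
# not on how the intermediate sigmas are spaced.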
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
return rope_params(end, dim, theta)
def apply_dense_rope(x: torch.Tensor, freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float32).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)
def _linear_input(linear: nn.Linear, x: torch.Tensor) -> torch.Tensor:
return x.to(dtype=linear.weight.dtype)
def _wan_layer_norm(norm: nn.Module, x: torch.Tensor) -> torch.Tensor:
if isinstance(norm, WanLayerNorm) and norm.weight is not None:
weight = norm.weight.float()
bias = norm.bias.float() if norm.bias is not None else None
return functional.layer_norm(x.float(), norm.normalized_shape, weight, bias, norm.eps).to(
dtype=x.dtype
)
return norm(x)
def create_group_causal_attn_mask(
num_temporal_groups: int, num_query_per_group: int, num_key_per_group: int, mode: str = "causal"
) -> torch.Tensor:
if mode not in ["causal", "group_diagonal"]:
raise ValueError(f"`mode` must be 'causal' or 'group_diagonal', got {mode}.")
if num_temporal_groups <= 0:
raise ValueError(f"`num_temporal_groups` must be positive, got {num_temporal_groups}.")
if num_query_per_group <= 0:
raise ValueError(f"`num_query_per_group` must be positive, got {num_query_per_group}.")
if num_key_per_group <= 0:
raise ValueError(f"`num_key_per_group` must be positive, got {num_key_per_group}.")
total_num_query_tokens = num_temporal_groups * num_query_per_group
total_num_key_tokens = num_temporal_groups * num_key_per_group
query_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_query_per_group).unsqueeze(1)
key_time_indices = torch.arange(num_temporal_groups).repeat_interleave(num_key_per_group).unsqueeze(0)
if mode == "causal":
attn_mask = query_time_indices >= key_time_indices
else:
attn_mask = query_time_indices == key_time_indices
if attn_mask.shape != (total_num_query_tokens, total_num_key_tokens):
raise RuntimeError("Attention mask shape mismatch.")
return attn_mask
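# Illustration only: the two mask modes of `create_group_causal_attn_mask` rebuilt with
# nested lists so the pattern can be read directly. `group_mask` is a hypothetical name
# for this sketch; True means the query token may attend to the key token.

```python
def group_mask(num_groups: int, q_per_group: int, k_per_group: int, mode: str = "causal"):
    """Return a [queries x keys] boolean mask as nested lists."""
    # Temporal group index of every query / key token (repeat_interleave equivalent).
    q_time = [g for g in range(num_groups) for _ in range(q_per_group)]
    k_time = [g for g in range(num_groups) for _ in range(k_per_group)]
    if mode == "causal":
        return [[q >= k for k in k_time] for q in q_time]
    if mode == "group_diagonal":
        return [[q == k for k in k_time] for q in q_time]
    raise ValueError(f"`mode` must be 'causal' or 'group_diagonal', got {mode}.")


# Two temporal groups, one query token and two key tokens per group.
for row in group_mask(2, 1, 2, mode="causal"):
    print(row)
for row in group_mask(2, 1, 2, mode="group_diagonal"):
    print(row)
```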
class FastWAMAttentionBlock(WanAttentionBlock):
"""Wan attention block with FastWAM's arbitrary boolean mask support."""
def __init__(
self,
hidden_dim: int,
attn_head_dim: int,
num_heads: int,
ffn_dim: int,
eps: float = 1e-6,
fp32_attention: bool = True,
):
attention_dim = attn_head_dim * num_heads
if hidden_dim == attention_dim:
super().__init__(
dim=hidden_dim,
ffn_dim=ffn_dim,
num_heads=num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=eps,
)
else:
nn.Module.__init__(self)
self.dim = hidden_dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = (-1, -1)
self.qk_norm = True
self.cross_attn_norm = True
self.eps = eps
self.norm1 = WanLayerNorm(hidden_dim, eps)
self.self_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
self.norm3 = WanLayerNorm(hidden_dim, eps, elementwise_affine=True)
self.cross_attn = _FastWAMProjectedAttention(hidden_dim, attention_dim, num_heads, eps)
self.norm2 = WanLayerNorm(hidden_dim, eps)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, ffn_dim),
nn.GELU(approximate="tanh"),
nn.Linear(ffn_dim, hidden_dim),
)
self.modulation = nn.Parameter(torch.randn(1, 6, hidden_dim) / hidden_dim**0.5)
self.attn_head_dim = attn_head_dim
self.fp32_attention = bool(fp32_attention)
@staticmethod
def split_modulation(block, t_mod: torch.Tensor):
has_seq = len(t_mod.shape) == 4
chunk_dim = 2 if has_seq else 1
base_mod = block.modulation.to(dtype=t_mod.dtype, device=t_mod.device)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (base_mod + t_mod).chunk(
6, dim=chunk_dim
)
if has_seq:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
shift_msa.squeeze(2),
scale_msa.squeeze(2),
gate_msa.squeeze(2),
shift_mlp.squeeze(2),
scale_mlp.squeeze(2),
gate_mlp.squeeze(2),
)
return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
def project_self_attention(
self, x: torch.Tensor, freqs: torch.Tensor | dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = self.self_attn.norm_q(self.self_attn.q(x))
k = self.self_attn.norm_k(self.self_attn.k(x))
v = self.self_attn.v(x)
if isinstance(freqs, dict):
b, s = x.shape[:2]
q = rope_apply(
q.view(b, s, self.num_heads, self.attn_head_dim),
freqs["grid_sizes"],
freqs["freqs"],
).flatten(2)
k = rope_apply(
k.view(b, s, self.num_heads, self.attn_head_dim),
freqs["grid_sizes"],
freqs["freqs"],
).flatten(2)
else:
q = apply_dense_rope(q, freqs, self.num_heads)
k = apply_dense_rope(k, freqs, self.num_heads)
return q, k, v
def apply_cross_attention(
self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor | None = None
) -> torch.Tensor:
if context_mask is not None and context_mask.dim() == 3:
context_mask = context_mask.unsqueeze(1)
attn = self.cross_attn
b, n, d = x.size(0), attn.num_heads, attn.head_dim
q = attn.norm_q(attn.q(x)).view(b, -1, n * d)
k = attn.norm_k(attn.k(context)).view(b, -1, n * d)
v = attn.v(context).view(b, -1, n * d)
x = fastwam_masked_attention(
q=q,
k=k,
v=v,
num_heads=n,
ctx_mask=context_mask,
fp32_attention=self.fp32_attention,
)
return attn.o(_linear_input(attn.o, x))
def project_self_attention_output(self, x: torch.Tensor) -> torch.Tensor:
return self.self_attn.o(_linear_input(self.self_attn.o, x))
def apply_norm1(self, x: torch.Tensor) -> torch.Tensor:
return _wan_layer_norm(self.norm1, x)
def apply_norm2(self, x: torch.Tensor) -> torch.Tensor:
return _wan_layer_norm(self.norm2, x)
def apply_norm3(self, x: torch.Tensor) -> torch.Tensor:
return _wan_layer_norm(self.norm3, x)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
t_mod: torch.Tensor,
freqs: torch.Tensor | dict[str, torch.Tensor],
context_mask: torch.Tensor | None = None,
self_attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.split_modulation(self, t_mod)
residual_x = x
attn_input = modulate(self.apply_norm1(x), shift_msa, scale_msa)
q, k, v = self.project_self_attention(attn_input, freqs)
y = fastwam_masked_attention(
q=q,
k=k,
v=v,
num_heads=self.num_heads,
ctx_mask=self_attn_mask,
fp32_attention=self.fp32_attention,
)
x = residual_x + gate_msa * self.project_self_attention_output(y)
x = x + self.apply_cross_attention(self.apply_norm3(x), context, context_mask=context_mask)
mlp_input = modulate(self.apply_norm2(x), shift_mlp, scale_mlp)
return x + gate_mlp * self.ffn(mlp_input)
class _FastWAMProjectedAttention(nn.Module):
def __init__(self, hidden_dim: int, attention_dim: int, num_heads: int, eps: float):
super().__init__()
self.dim = hidden_dim
self.num_heads = num_heads
self.head_dim = attention_dim // num_heads
self.q = nn.Linear(hidden_dim, attention_dim)
self.k = nn.Linear(hidden_dim, attention_dim)
self.v = nn.Linear(hidden_dim, attention_dim)
self.o = nn.Linear(attention_dim, hidden_dim)
self.norm_q = WanRMSNorm(attention_dim, eps=eps)
self.norm_k = WanRMSNorm(attention_dim, eps=eps)
class WanVideoDiT(WanModel):
def __init__(
self,
hidden_dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: tuple[int, int, int],
num_heads: int,
attn_head_dim: int,
num_layers: int,
has_image_input: bool = False,
has_image_pos_emb: bool = False,
has_ref_conv: bool = False,
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
separated_timestep: bool = False,
require_vae_embedding: bool = False,
require_clip_embedding: bool = False,
fuse_vae_embedding_in_latents: bool = True,
action_conditioned: bool = False,
action_dim: int = 7,
action_group_causal_mask_mode: str = "causal",
video_attention_mask_mode: str = "bidirectional",
use_gradient_checkpointing: bool = False,
fp32_attention: bool = True,
):
del in_dim_control_adapter
if has_image_input:
raise ValueError("FastWAM currently expects Wan2.2 TI2V latents with fused image conditioning.")
if has_image_pos_emb:
raise ValueError("FastWAM does not support extra image positional embeddings in WanVideoDiT.")
if has_ref_conv:
raise ValueError("FastWAM does not support reference convolutions in WanVideoDiT.")
if add_control_adapter:
raise ValueError("FastWAM does not support control adapters in WanVideoDiT.")
if require_clip_embedding:
raise ValueError("FastWAM does not support CLIP embedding conditioning in WanVideoDiT.")
if require_vae_embedding or not fuse_vae_embedding_in_latents:
raise ValueError("FastWAM expects VAE conditioning to be fused in latents.")
if attn_head_dim != hidden_dim // num_heads:
raise ValueError(
"`attn_head_dim` must match the upstream Wan head dimension `hidden_dim // num_heads`; "
f"got {attn_head_dim} vs {hidden_dim // num_heads}."
)
super().__init__(
model_type="ti2v",
patch_size=patch_size,
text_len=512,
in_dim=in_dim,
dim=hidden_dim,
ffn_dim=ffn_dim,
freq_dim=freq_dim,
text_dim=text_dim,
out_dim=out_dim,
num_heads=num_heads,
num_layers=num_layers,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=eps,
)
self.blocks = torch.nn.ModuleList(
[
FastWAMAttentionBlock(
hidden_dim=hidden_dim,
attn_head_dim=attn_head_dim,
num_heads=num_heads,
ffn_dim=ffn_dim,
eps=eps,
fp32_attention=fp32_attention,
)
for _ in range(num_layers)
]
)
self.init_weights()
self.hidden_dim = hidden_dim
self.attn_head_dim = attn_head_dim
self.separated_timestep = separated_timestep
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.video_attention_mask_mode = str(video_attention_mask_mode)
self.action_conditioned = action_conditioned
self.action_dim = action_dim
self.fp32_attention = bool(fp32_attention)
if self.action_conditioned:
self.action_embedding = torch.nn.Linear(action_dim, hidden_dim)
self.action_group_causal_mask_mode = action_group_causal_mask_mode
self.use_gradient_checkpointing = use_gradient_checkpointing
if self.use_gradient_checkpointing:
logger.info(
"Using gradient checkpointing for DiT blocks. This will save memory but use more computation."
)
def patchify(self, x: torch.Tensor):
return self.patch_embedding(x)
def _validate_forward_inputs(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor | None,
action: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if x.ndim != 5:
raise ValueError(f"`latents` must be 5D [B, C, T, H, W], got shape {tuple(x.shape)}")
num_latent_frames = x.shape[2]
if context.ndim != 3:
raise ValueError(f"`context` must be 3D [B, L, D], got shape {tuple(context.shape)}")
if timestep.ndim != 1:
raise ValueError(f"`timestep` must be 1D [B] or [1], got shape {tuple(timestep.shape)}")
if self.action_conditioned:
allow_text_only_single_frame = num_latent_frames == 1 and action is None
if not allow_text_only_single_frame:
if action is None:
raise ValueError("Action input is required for action-conditioned model.")
if action.ndim != 3:
raise ValueError(
f"`action` must be 3D [B, action_horizon, action_dim], got shape {tuple(action.shape)}"
)
if action.shape[2] != self.action_dim:
raise ValueError(
f"`action` last dimension must be {self.action_dim}, got {action.shape[2]}"
)
if num_latent_frames <= 1:
raise ValueError(
f"video length must be > 1 for action-conditioned model, got {num_latent_frames}"
)
if action.shape[1] % (num_latent_frames - 1) != 0:
raise ValueError(
"action horizon must be divisible by (num_latent_frames - 1), "
f"got action_horizon={action.shape[1]}"
)
if context_mask is None:
context_mask = torch.ones(
(context.shape[0], context.shape[1]), dtype=torch.bool, device=context.device
)
else:
if context_mask.ndim != 2:
raise ValueError(f"`context_mask` must be 2D [B, L], got shape {tuple(context_mask.shape)}")
if context_mask.shape[0] != context.shape[0] or context_mask.shape[1] != context.shape[1]:
raise ValueError(
"`context_mask` shape must match `context` shape [B, L], "
f"got {tuple(context_mask.shape)} vs {tuple(context.shape)}"
)
batch_size = x.shape[0]
if batch_size != context.shape[0]:
if not self.training and batch_size == 1:
x = x.expand(context.shape[0], -1, -1, -1, -1)
batch_size = context.shape[0]
else:
raise ValueError(
f"Batch mismatch between latents and context: {batch_size} vs {context.shape[0]}."
)
if timestep.shape[0] not in (1, batch_size):
raise ValueError(
f"`timestep` length must be 1 or batch_size({batch_size}), got {timestep.shape[0]}"
)
if timestep.shape[0] == 1 and batch_size > 1:
if self.training:
raise ValueError("During training, timestep length must match batch_size.")
timestep = timestep.expand(batch_size)
return x, timestep, context_mask
def build_video_to_video_mask(
self,
video_seq_len: int,
video_tokens_per_frame: int,
device: torch.device,
) -> torch.Tensor:
if video_seq_len <= 0:
raise ValueError(f"`video_seq_len` must be positive, got {video_seq_len}")
if video_tokens_per_frame <= 0:
raise ValueError(f"`video_tokens_per_frame` must be positive, got {video_tokens_per_frame}")
if self.video_attention_mask_mode == "bidirectional":
return torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)
if self.video_attention_mask_mode == "per_frame_causal":
if video_seq_len % video_tokens_per_frame != 0:
raise ValueError(
"`video_seq_len` must be divisible by `video_tokens_per_frame` in `per_frame_causal` mode, "
f"got {video_seq_len} and {video_tokens_per_frame}"
)
num_video_frames = video_seq_len // video_tokens_per_frame
frame_causal = torch.tril(
torch.ones((num_video_frames, num_video_frames), dtype=torch.bool, device=device)
)
return frame_causal.repeat_interleave(video_tokens_per_frame, dim=0).repeat_interleave(
video_tokens_per_frame, dim=1
)
if self.video_attention_mask_mode == "first_frame_causal":
video_mask = torch.ones((video_seq_len, video_seq_len), dtype=torch.bool, device=device)
first_frame_tokens = min(video_tokens_per_frame, video_seq_len)
video_mask[:first_frame_tokens, first_frame_tokens:] = False
return video_mask
raise ValueError(f"Unsupported video attention mask mode: {self.video_attention_mask_mode}")
def pre_dit(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor | None = None,
action: torch.Tensor | None = None,
fuse_vae_embedding_in_latents: bool = False,
) -> dict[str, Any]:
x, timestep, context_mask = self._validate_forward_inputs(
x=x,
timestep=timestep,
context=context,
context_mask=context_mask,
action=action,
)
model_dtype = self.patch_embedding.weight.dtype
x = x.to(dtype=model_dtype)
context = context.to(dtype=model_dtype)
if action is not None:
action = action.to(dtype=model_dtype)
batch_size = x.shape[0]
patch_h = int(self.patch_size[1])
patch_w = int(self.patch_size[2])
if x.shape[3] % patch_h != 0 or x.shape[4] % patch_w != 0:
raise ValueError(
"Latent spatial shape must be divisible by DiT patch size, "
f"got HxW=({x.shape[3]}, {x.shape[4]}), patch=({patch_h}, {patch_w})"
)
tokens_per_frame = (x.shape[3] // patch_h) * (x.shape[4] // patch_w)
if not (self.separated_timestep and fuse_vae_embedding_in_latents):
raise NotImplementedError(
"FastWAM currently requires separated timesteps with fused VAE latents."
)
token_timesteps = torch.ones(
(batch_size, x.shape[2], tokens_per_frame),
dtype=model_dtype,
device=timestep.device,
) * timestep.to(dtype=model_dtype).view(batch_size, 1, 1)
token_timesteps[:, 0, :] = 0
token_timesteps = token_timesteps.reshape(batch_size, -1)
# Wan keeps the time embedding in fp32: the AdaLN modulation in the vendored
# Head/Block asserts e.dtype == float32 (numerical stability of the scale/shift).
# Upstream guarantees this via an fp32 autocast region, so it holds even when the
# model runs in bf16. Mirror that here, then cast the per-block modulation back to
# model_dtype so the bf16 attention blocks are not upcast to fp32.
with torch.amp.autocast("cuda", dtype=torch.float32):
token_t_emb = sinusoidal_embedding_1d(self.freq_dim, token_timesteps.reshape(-1)).float()
t = self.time_embedding(token_t_emb).reshape(batch_size, -1, self.hidden_dim)
t_mod = self.time_projection(t).unflatten(2, (6, self.hidden_dim))
t_mod = t_mod.to(dtype=model_dtype)
x = self.patchify(x)
f, h, w = x.shape[2:]
context = self.text_embedding(context)
context_len = context.shape[1]
if self.action_conditioned and action is not None:
action_len = action.shape[1]
action_emb = self.action_embedding(action)
action_pos_embed = sinusoidal_embedding_1d(
self.hidden_dim, torch.arange(action_len, device=action_emb.device)
).to(dtype=action_emb.dtype)
action_emb = action_emb + action_pos_embed.unsqueeze(0)
context = torch.cat([context, action_emb], dim=1)
num_temporal_groups = f - 1
if num_temporal_groups <= 0:
raise ValueError(
"Action-conditioned context mask requires at least 2 latent frames when `action` is provided."
)
if action_emb.shape[1] % num_temporal_groups != 0:
raise ValueError(
f"Action embedding length {action_emb.shape[1]} must be divisible by "
f"number of temporal groups {num_temporal_groups}"
)
action_group_mask = create_group_causal_attn_mask(
num_temporal_groups=num_temporal_groups,
num_query_per_group=tokens_per_frame,
num_key_per_group=action_len // num_temporal_groups,
mode=self.action_group_causal_mask_mode,
).to(context.device)
seq_len = f * h * w
final_context_mask = torch.zeros(
(batch_size, seq_len, context.shape[1]), dtype=torch.bool, device=context.device
)
final_context_mask[:, :, :context_len] = context_mask.unsqueeze(1).expand(-1, seq_len, -1)
final_context_mask[:, tokens_per_frame:, context_len:] = action_group_mask.unsqueeze(0).expand(
batch_size, -1, -1
)
context_mask = final_context_mask
elif self.action_conditioned and action is None:
if f != 1:
raise ValueError(
"Action-conditioned model requires `action` unless running single-frame text-only mode "
"with num_latent_frames=1."
)
context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)
else:
context_mask = context_mask.unsqueeze(1).expand(-1, f * h * w, -1)
x_tokens = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
grid_sizes = torch.tensor([[f, h, w]] * batch_size, dtype=torch.long, device=x_tokens.device)
freqs = {"grid_sizes": grid_sizes, "freqs": self.freqs.to(x_tokens.device)}
return {
"tokens": x_tokens,
"freqs": freqs,
"t": t,
"t_mod": t_mod,
"context": context,
"context_mask": context_mask,
"meta": {
"grid_sizes": grid_sizes,
"tokens_per_frame": tokens_per_frame,
"batch_size": batch_size,
},
}
def post_dit(self, x_tokens: torch.Tensor, pre_state: dict[str, Any]) -> torch.Tensor:
x = self.head(x_tokens, pre_state["t"])
return torch.stack(super().unpatchify(x, pre_state["meta"]["grid_sizes"]))
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
context: torch.Tensor,
context_mask: torch.Tensor | None = None,
action: torch.Tensor | None = None,
fuse_vae_embedding_in_latents: bool = False,
):
pre_state = self.pre_dit(
x=x,
timestep=timestep,
context=context,
context_mask=context_mask,
action=action,
fuse_vae_embedding_in_latents=fuse_vae_embedding_in_latents,
)
x_tokens = pre_state["tokens"]
context_emb = pre_state["context"]
t_mod = pre_state["t_mod"]
freqs = pre_state["freqs"]
context_attn_mask = pre_state["context_mask"]
self_attn_mask = (
self.build_video_to_video_mask(
video_seq_len=x_tokens.shape[1],
video_tokens_per_frame=int(pre_state["meta"]["tokens_per_frame"]),
device=x_tokens.device,
)
if self.video_attention_mask_mode != "bidirectional"
else None
)
for block in self.blocks:
if self.use_gradient_checkpointing:
x_tokens = gradient_checkpoint_forward(
block,
self.use_gradient_checkpointing,
x_tokens,
context_emb,
t_mod,
freqs,
context_mask=context_attn_mask,
self_attn_mask=self_attn_mask,
)
else:
x_tokens = block(
x_tokens,
context_emb,
t_mod,
freqs,
context_mask=context_attn_mask,
self_attn_mask=self_attn_mask,
)
return self.post_dit(x_tokens, pre_state)
__all__ = [
"FastWAMAttentionBlock",
"WanContinuousFlowMatchScheduler",
"WanVideoDiT",
"apply_dense_rope",
"create_group_causal_attn_mask",
"fastwam_masked_attention",
"gradient_checkpoint_forward",
"modulate",
"precompute_freqs_cis",
"sinusoidal_embedding_1d",
]
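The two non-trivial modes of `build_video_to_video_mask` above can be illustrated without tensors. The following is a minimal, dependency-free sketch, not the module's implementation: the helper names `per_frame_causal_mask`, `first_frame_causal_mask`, and `render` are hypothetical, and plain nested lists stand in for the boolean `torch` mask. `True` (printed as `1`) means the query token in that row may attend to the key token in that column.

```python
def per_frame_causal_mask(num_frames: int, tokens_per_frame: int) -> list[list[bool]]:
    """Query token i may attend to key token j iff frame(j) <= frame(i)."""
    seq_len = num_frames * tokens_per_frame
    return [
        [(j // tokens_per_frame) <= (i // tokens_per_frame) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def first_frame_causal_mask(num_frames: int, tokens_per_frame: int) -> list[list[bool]]:
    """First-frame queries see only the first frame; all other queries see everything."""
    seq_len = num_frames * tokens_per_frame
    first = min(tokens_per_frame, seq_len)
    return [
        [not (i < first and j >= first) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def render(mask: list[list[bool]]) -> list[str]:
    """Render each mask row as a string of 1s (allowed) and 0s (blocked)."""
    return ["".join("1" if allowed else "0" for allowed in row) for row in mask]


# Three frames with one token each make the difference between the modes visible.
print(render(per_frame_causal_mask(3, 1)))    # ['100', '110', '111']
print(render(first_frame_causal_mask(3, 1)))  # ['100', '111', '111']
```

In the sketch, the per-frame mode is a lower-triangular frame mask expanded to token granularity, which is what the `torch.tril` plus double `repeat_interleave` construction in the source produces.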
+1 -9
@@ -18,12 +18,4 @@ from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
def __getattr__(name: str):
if name in {"GR00TN17", "GR00TN17Config"}:
from .groot_n1_7 import GR00TN17, GR00TN17Config
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
def swish(x):
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
"""
Produces a sinusoidal encoding of shape (B, T, w)
given timesteps of shape (B, T).
"""
def __init__(self, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps):
# timesteps: shape (B, T)
# Compute sin/cos features for every (B, T) entry along a new last dimension
timesteps = timesteps.float() # ensure float
b, t = timesteps.shape
device = timesteps.device
half_dim = self.embedding_dim // 2
# typical log space frequencies for sinusoidal encoding
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
# Expand timesteps to (B, T, 1) then multiply
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
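The arithmetic inside `SinusoidalPositionalEncoding.forward` can be checked on a single scalar timestep. This is a minimal sketch using only the standard library; the function name `sinusoidal_encoding` is hypothetical and it handles one timestep at a time rather than a `(B, T)` tensor.

```python
import math


def sinusoidal_encoding(timestep: float, embedding_dim: int) -> list[float]:
    """Return [sin features..., cos features...] for one scalar timestep."""
    half_dim = embedding_dim // 2
    # Log-spaced frequencies: exp(-i * ln(10000) / half_dim), as in the module above.
    freqs = [
        timestep * math.exp(-i * math.log(10000.0) / half_dim)
        for i in range(half_dim)
    ]
    return [math.sin(f) for f in freqs] + [math.cos(f) for f in freqs]


enc = sinusoidal_encoding(0.0, 8)
print(enc)  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]

# An odd embedding_dim yields one fewer feature, because half_dim rounds down.
print(len(sinusoidal_encoding(1.0, 7)))  # 6
```

At timestep 0 every sine feature is 0 and every cosine feature is 1, which makes a convenient sanity check. The output width is `2 * (embedding_dim // 2)`, so it equals `embedding_dim` only for even values.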
@@ -14,7 +14,6 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
import torch
@@ -43,9 +42,6 @@ else:
Timesteps = None
logger = logging.getLogger(__name__)
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
require_package("diffusers", extra="groot")
@@ -185,7 +181,8 @@ class BasicTransformerBlock(nn.Module):
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
@@ -269,8 +266,8 @@ class DiT(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
logger.debug(
"Total number of DiT parameters: %d",
print(
"Total number of DiT parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -321,71 +318,6 @@ class DiT(ModelMixin, ConfigMixin):
return self.proj_out_2(hidden_states)
class AlternateVLDiT(DiT):
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
super().__init__(*args, **kwargs)
self.attend_text_every_n_blocks = attend_text_every_n_blocks
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
return_all_hidden_states: bool = False,
image_mask: torch.Tensor | None = None,
backbone_attention_mask: torch.Tensor | None = None,
):
if image_mask is None:
raise ValueError("image_mask is required for AlternateVLDiT.")
if backbone_attention_mask is None:
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
image_attention_mask = image_mask & backbone_attention_mask
non_image_attention_mask = (~image_mask) & backbone_attention_mask
all_hidden_states = [hidden_states]
if not self.config.interleave_self_attention:
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
curr_encoder_attention_mask = (
non_image_attention_mask
if idx % (2 * self.attend_text_every_n_blocks) == 0
else image_attention_mask
)
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=curr_encoder_attention_mask,
temb=temb,
)
all_hidden_states.append(hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
return self.proj_out_2(hidden_states)
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@@ -430,8 +362,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
for _ in range(self.config.num_layers)
]
)
logger.debug(
"Total number of SelfAttentionTransformer parameters: %d",
print(
"Total number of SelfAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -0,0 +1,408 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import field
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
else:
PretrainedConfig = object
BatchFeature = None
from .action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
from .cross_attention_dit import DiT, SelfAttentionTransformer
class CategorySpecificLinear(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim):
super().__init__()
self.num_categories = num_categories
# For each category, we have separate weights and biases.
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x, cat_ids):
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
super().__init__()
self.num_categories = num_categories
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x, cat_ids):
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(self, action_dim, hidden_size, num_embodiments):
super().__init__()
self.hidden_size = hidden_size
self.num_embodiments = num_embodiments
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions, timesteps, cat_ids):
"""
actions: shape (B, T, action_dim)
timesteps: shape (B,) -- a single scalar per batch item
cat_ids: shape (B,)
returns: shape (B, T, hidden_size)
"""
b, t, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# e.g. if timesteps is (B,), replicate across T
if timesteps.dim() == 1 and timesteps.shape[0] == b:
# shape (B,) => (B,T)
timesteps = timesteps.unsqueeze(1).expand(-1, t)
else:
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions, cat_ids)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = swish(self.W2(x, cat_ids))
# 5) Finally W3 => (B, T, w)
x = self.W3(x, cat_ids)
return x
class FlowmatchingActionHeadConfig(PretrainedConfig):
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
backbone_embedding_dim: int = field(
default=1536, metadata={"help": "Backbone embedding channel dimension."}
)
hidden_size: int = field(default=1024, metadata={"help": "Hidden dimension of the state encoder and action decoder MLPs."})
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
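noise_beta_alpha: float = field(default=1.5, metadata={"help": "Alpha parameter of the Beta distribution used to sample flow matching time."})
noise_beta_beta: float = field(default=1.0, metadata={"help": "Beta parameter of the Beta distribution used to sample flow matching time."})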
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
num_timestep_buckets: int = field(
default=1000, metadata={"help": "Number of timestep discretization buckets."}
)
num_inference_timesteps: int = field(
default=None,
metadata={"help": "Number of inference steps for noise diffusion."},
)
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
tune_diffusion_model: bool = field(
default=True, metadata={"help": "Whether to tune the diffusion model."}
)
load_pretrained_det_decode_layer_path: str = field(
default=None, metadata={"help": "Path to pretrained detection model."}
)
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
freeze_decode_layer: bool = field(default=False)
expand_batch: int = field(default=None)
use_vlln: bool = field(default=True)
vl_self_attention_cfg: dict = field(default=None)
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
class FlowmatchingActionHead(nn.Module):
config_class = FlowmatchingActionHeadConfig
supports_gradient_checkpointing = True
def __init__(
self,
config: FlowmatchingActionHeadConfig,
):
super().__init__()
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
self.model = DiT(**config.diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=config.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
self.vl_self_attention = (
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
)
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
for p in self.parameters():
p.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
print(f"Tune action head projector: {self.tune_projector}")
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_projector and not tune_diffusion_model:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Action head trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No action head trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
def sample_time(self, batch_size, device, dtype):
if self._beta_dist is None:
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = backbone_output["backbone_features"]
backbone_features = self.vlln(backbone_features)
backbone_features = self.vl_self_attention(backbone_features)
backbone_output["backbone_features"] = backbone_features
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
# Set frozen modules to eval
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
if self.config.expand_batch is not None:
for k, v in backbone_output.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
backbone_output[k] = expanded
for k, v in action_input.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
action_input[k] = expanded
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
device = vl_embs.device
# Get embodiment ID.
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Embed noised action trajectory.
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None] # shape (B,1,1) for broadcast
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
# Convert (continuous) t -> discrete if needed
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
vl_attn_mask = backbone_output.backbone_attention_mask
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=vl_attn_mask,
timestep=t_discretized,
return_all_hidden_states=False, # NOTE (YL): not using flare now
)
pred = self.action_decoder(model_output, embodiment_id)
# Keep only the trailing action tokens of the prediction (drop state and future tokens).
pred_actions = pred[:, -actions.shape[1] :]
action_mask = action_input.action_mask
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = loss.sum() / action_mask.sum()
output_dict = {
"loss": loss,
}
return BatchFeature(data=output_dict)
@torch.no_grad()
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Set initial actions as the sampled noise.
batch_size = vl_embs.shape[0]
device = vl_embs.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=vl_embs.dtype,
device=device,
)
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# Run denoising steps.
for t in range(num_steps):
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# Embed noised action trajectory.
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
# Run model forward.
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -self.action_horizon :]
# Update actions using euler integration.
actions = actions + dt * pred_velocity
return BatchFeature(data={"action_pred": actions})
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype
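The training target in `FlowmatchingActionHead.forward` and the sampling loop in `get_action` fit together as follows: training interpolates linearly between noise and the action and regresses the constant velocity `actions - noise`, while inference integrates a predicted velocity from `t = 0` to `t = 1` with Euler steps. The sketch below shows this on scalars under stated assumptions: the names `interpolate`, `euler_denoise`, and `oracle_velocity` are hypothetical, and the oracle stands in for the trained DiT by returning the exact velocity.

```python
from typing import Callable


def interpolate(noise: float, action: float, t: float) -> float:
    """Noisy sample at time t: pure noise at t=0, clean action at t=1."""
    return (1.0 - t) * noise + t * action


def euler_denoise(
    noise: float,
    velocity_fn: Callable[[float, float], float],
    num_steps: int,
) -> float:
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with fixed-size Euler steps."""
    x = noise
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step / num_steps  # 0, 1/N, 2/N, ...
        x = x + dt * velocity_fn(x, t)
    return x


noise, target = -0.5, 2.0


def oracle_velocity(x: float, t: float) -> float:
    # Stand-in for the learned model: the exact regression target `actions - noise`.
    return target - noise


result = euler_denoise(noise, oracle_velocity, num_steps=4)
print(result)  # 2.0
```

Because the oracle velocity is constant along the straight path from noise to action, Euler integration recovers the target for any number of steps. With a learned velocity the result is only approximate, so the number of inference steps becomes a quality and latency trade-off.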
+37 -337
@@ -14,228 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from .utils import read_json
logger = logging.getLogger(__name__)
GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5"
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
GROOT_N1_5_REMOVAL_GUIDANCE = (
"GR00T N1.5 support was removed from LeRobot. "
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Default GR00T N1.7 training resolution, used as a fallback when processor_config lacks sizing
# information. Forcing a resize prevents mismatched full-resolution patchification.
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7,
"n1_7": GROOT_N1_7,
"n1d7": GROOT_N1_7,
"n17": GROOT_N1_7,
"1.7": GROOT_N1_7,
}
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
"none": None,
"": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
}
def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = GROOT_N1_7
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
return normalized
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
if transform is None:
return None
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
supported = ", ".join(
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
)
raise ValueError(
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
f"Supported transforms: none, {supported}."
)
return normalized
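The sentinel handling above can be sketched in isolation. This is a minimal illustration, not the module's code: `ALIASES`, `normalize_transform`, and `resolve_transform` are hypothetical names, and the resolution step mirrors what the config's `__post_init__` is described as doing with the `'auto'` value.

```python
AUTO = "auto"      # sentinel: the user made no explicit choice
LIBERO = "libero"

# Lower-cased spelling -> canonical value; None means "no transform".
ALIASES = {AUTO: AUTO, "none": None, "": None, LIBERO: LIBERO}


def normalize_transform(transform):
    """Map a user-supplied spelling to its canonical value, rejecting unknown names."""
    if transform is None:
        return None
    key = transform.lower()
    # Membership is checked explicitly because None is also a valid mapped value.
    if key not in ALIASES:
        raise ValueError(f"Unsupported action decode transform '{transform}'.")
    return ALIASES[key]


def resolve_transform(transform, embodiment_tag):
    """Resolve the 'auto' sentinel to the embodiment default; keep explicit choices."""
    normalized = normalize_transform(transform)
    if normalized == AUTO:
        return LIBERO if embodiment_tag == "libero_sim" else None
    return normalized
```

Because `'none'` normalizes to `None` rather than to the sentinel, an explicit opt-out is never overwritten by the embodiment default.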
def infer_groot_model_version(model_path: str | None) -> str | None:
if not model_path:
return None
model_path_lower = model_path.lower()
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
return GROOT_N1_7
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
# N1.5 is unsupported, but it must still be recognised here to fail loudly
# rather than silently treating an N1.5 checkpoint as N1.7.
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
return GROOT_N1_5
config_version = _infer_groot_model_version_from_local_config(model_path)
if config_version is not None:
return config_version
return None
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
if model_path is None:
return False
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return False
config = read_json(config_path)
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
processor_config = read_json(processor_config_path)
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if "libero_sim" in modality_configs:
return "libero_sim"
if len(modality_configs) == 1:
return next(iter(modality_configs))
return None
def infer_groot_n1_7_action_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
processor_config = read_json(processor_config_path)
processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict):
return None
modality_configs = processor_kwargs.get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag is None:
return None
embodiment_config = modality_configs.get(embodiment_tag, {})
if not isinstance(embodiment_config, dict):
return None
action_config = embodiment_config.get("action", {})
if not isinstance(action_config, dict):
return None
delta_indices = action_config.get("delta_indices", [])
if not isinstance(delta_indices, list):
return None
return len(delta_indices) or None
def infer_groot_n1_7_action_execution_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
if action_horizon is None:
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag == "libero_sim":
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
# actions. Keeping that execution cadence avoids stale open-loop chunks.
return min(action_horizon, 8)
return action_horizon
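The two horizon helpers above reduce to a dictionary lookup followed by an embodiment-specific cap. The sketch below works on an in-memory dict instead of reading `processor_config.json`; the function names and the sample config are hypothetical, and the type checks of the original are omitted for brevity.

```python
def action_horizon(processor_config, embodiment_tag):
    """Number of decoded actions, taken from the length of `delta_indices`."""
    modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
    action_config = modality_configs.get(embodiment_tag, {}).get("action", {})
    delta_indices = action_config.get("delta_indices", [])
    # An empty or missing list yields None rather than 0.
    return len(delta_indices) or None


def execution_horizon(processor_config, embodiment_tag, libero_cap=8):
    """Actions to execute before replanning; capped for the 'libero_sim' embodiment."""
    horizon = action_horizon(processor_config, embodiment_tag)
    if horizon is None:
        return None
    if embodiment_tag == "libero_sim":
        return min(horizon, libero_cap)
    return horizon


# Hypothetical config fragment, shaped like the fields the helpers read.
sample_config = {
    "processor_kwargs": {
        "modality_configs": {
            "libero_sim": {"action": {"delta_indices": list(range(16))}},
            "gr1": {"action": {"delta_indices": list(range(16))}},
        }
    }
}
```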
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return None
return _infer_groot_model_version_from_config(read_json(config_path))
def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version")
if isinstance(model_version, str):
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
return GROOT_N1_5
try:
return normalize_groot_model_version(model_version)
except ValueError:
return None
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
for candidate in candidates:
if not isinstance(candidate, str):
continue
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
backbone_cfg = config.get("backbone_cfg")
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
return GROOT_N1_5
return None
@PreTrainedConfig.register_subclass("groot")
@dataclass
@@ -244,44 +28,35 @@ class GrootConfig(PreTrainedConfig):
# Basic policy settings
n_obs_steps: int = 1
chunk_size: int = 40
n_action_steps: int = 40
chunk_size: int = 50
n_action_steps: int = 50
# Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 132
max_state_dim: int = 64
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 132
max_action_dim: int = 32
# GR00T normalizes state/action internally in its processor steps (min/max with
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
# handles image normalization. The policy therefore does NOT use LeRobot's
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
# Normalization (start with identity, adjust as needed)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Groot-specific model parameters
# Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
# model *source*, and is intentionally distinct from the inherited `pretrained_path`:
# `pretrained_path` (`--policy.path`) points at a saved LeRobot checkpoint directory whose
# `config.json` carries a `type` field, whereas a raw NVIDIA GR00T checkpoint has no such
# field and so can only be loaded through `base_model_path` (`--policy.base_model_path`).
# Defaults to GROOT_N1_7_BASE_MODEL when unset (resolved in __post_init__).
base_model_path: str | None = None
# Groot-specific model parameters (from groot_finetune_script.py)
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Path or HuggingFace model ID for the base Groot model
base_model_path: str = "nvidia/GR00T-N1.5-3B"
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -300,31 +75,20 @@ class GrootConfig(PreTrainedConfig):
# Whether to fine-tune the diffusion model
tune_diffusion_model: bool = True
# Whether to fine-tune the VL LayerNorm + VL self-attention projector in the action head.
tune_vlln: bool = True
# LoRA parameters (from groot_finetune_script.py)
# Rank for the LORA model. If 0, no LORA will be used.
lora_rank: int = 0
# Number of top LLM backbone layers to fine-tune (0 = none). Lets you adapt just the final
# language layers without unfreezing the whole backbone; independent of `tune_llm`, which tunes
# the entire LLM.
tune_top_llm_layers: int = 0
# Alpha value for the LORA model
lora_alpha: int = 16
# Inference-time knob: Number of flow-matching denoising steps used to decode an action chunk.
# Trades inference latency for action quality.
# None keeps the checkpoint value (GR00T N1.7 default: 4).
num_inference_timesteps: int | None = None
# Dropout rate for the LORA model
lora_dropout: float = 0.1
# Inference-time knob: Real-Time Chunking (RTC) overlap-blend ramp rate, used when the RTC engine
# supplies a previous-chunk prefix. Higher values blend the overlapping prefix more aggressively.
# None keeps the checkpoint value (GR00T N1.7 default: 6.0).
rtc_ramp_rate: float | None = None
# Whether to use the full model for LORA
lora_full_model: bool = False
# Inference-time knob: Whether to request the flash-attention-2 kernel for the Qwen3-VL backbone.
# flash-attn is an optional, user-managed optimization; when it is absent (the default),
# the backbone transparently falls back to SDPA, which is numerically equivalent.
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
use_flash_attention: bool = False
# Training parameters
# Training parameters (matching groot_finetune_script.py)
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
optimizer_eps: float = 1e-8
@@ -332,22 +96,17 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05
use_bf16: bool = True
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
# unused by the LeRobot N1.7 implementation except the `tokenizer_assets_repo` N1.5 tripwire and
# the `image_size` legacy remap in __post_init__. They are kept ONLY so a config.json saved by an
# earlier lerobot release (notably a GR00T N1.5 checkpoint) still parses under draccus — which
# rejects unknown fields — and is then rejected with a clear N1.5 removal message rather than an
# opaque draccus decoding error.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
lora_rank: int = 0
lora_alpha: int = 16
lora_dropout: float = 0.1
lora_full_model: bool = False
# Dataset parameters
# Video backend to use for training ('decord' or 'torchvision_av')
video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t"
save_steps: int = 1000
@@ -358,65 +117,6 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
self.base_model_path = GROOT_N1_7_BASE_MODEL
# The N1.7 LIBERO checkpoints emit a [0, 1] gripper action, but the LIBERO
# simulator expects the OpenVLA/[-1, 1] sign convention. NVIDIA's rollout
# wrapper applies this conversion; mirror it here so eval on the
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
# 'none' (normalized to None above) keeps the transform disabled.
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
self.action_decode_transform = (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
)
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
# warning. The dataclass defaults are already the N1.7 values, so a plain
# GrootConfig() never triggers this.
legacy_default_remaps = (
("max_state_dim", 64, 132),
("max_action_dim", 32, 132),
("chunk_size", 50, 40),
("n_action_steps", 50, 40),
("image_size", (224, 224), (256, 256)),
)
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
current_value = getattr(self, field_name)
if isinstance(legacy_value, tuple):
current_value = tuple(current_value)
if current_value == legacy_value:
logger.warning(
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
"if this is not what you want.",
field_name,
legacy_value,
n1_7_value,
)
setattr(self, field_name, n1_7_value)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != GROOT_N1_7:
message = (
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
super().__post_init__()
if self.n_action_steps > self.chunk_size:
@@ -424,6 +124,9 @@ class GrootConfig(PreTrainedConfig):
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
)
# groot_repo_path is now optional since we ported the components
# No validation needed
def validate_features(self) -> None:
"""Validate and set up input/output features for Groot."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
@@ -489,10 +192,7 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
return list(range(min(self.chunk_size, model_action_horizon)))
return list(range(min(self.chunk_size, 16)))
@property
def reward_delta_indices(self) -> None:
@@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Eagle25VLConfig(PretrainedConfig):
model_type = "eagle_2_5_vl"
is_composition = True
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
def __init__(
self,
vision_config=None,
text_config=None,
use_backbone_lora=0,
use_llm_lora=0,
pad2square=False,
select_layer=-4,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
loss_version="v1",
min_dynamic_tiles=1,
max_dynamic_tiles=6,
mlp_checkpoint=False,
initializer_range=0.02,
_attn_implementation="flash_attention_2",
_attn_implementation_autoset=False,
llm_config=None,
image_token_index=None,
use_pixel_shuffle=True,
mlp_connector_layers=2,
**kwargs,
):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {"model_type": "siglip_vision_model"}
logger.info("vision_config is None. Initializing the SiglipVisionConfig with default values.")
if text_config is None:
text_config = {"architectures": ["Qwen2ForCausalLM"]}
logger.info(
"text_config is None. Initializing the text config with default values (`Qwen2Config`)."
)
if vision_config["model_type"] == "siglip_vision_model":
self.vision_config = SiglipVisionConfig(**vision_config)
else:
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
if text_config["architectures"][0] == "LlamaForCausalLM":
self.text_config = LlamaConfig(**text_config)
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
self.text_config = Qwen2Config(**text_config)
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
self.text_config = Qwen3Config(**text_config)
else:
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.mlp_checkpoint = mlp_checkpoint
self.pad2square = pad2square
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.loss_version = loss_version
self.initializer_range = initializer_range
self.min_dynamic_tiles = min_dynamic_tiles
self.max_dynamic_tiles = max_dynamic_tiles
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self._attn_implementation = _attn_implementation
self._attn_implementation_autoset = _attn_implementation_autoset
self.image_token_index = image_token_index
self.use_pixel_shuffle = use_pixel_shuffle
self.mlp_connector_layers = mlp_connector_layers
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
output["use_backbone_lora"] = self.use_backbone_lora
output["use_llm_lora"] = self.use_llm_lora
output["pad2square"] = self.pad2square
output["select_layer"] = self.select_layer
output["force_image_size"] = self.force_image_size
output["downsample_ratio"] = self.downsample_ratio
output["template"] = self.template
output["dynamic_image_size"] = self.dynamic_image_size
output["use_thumbnail"] = self.use_thumbnail
output["min_dynamic_tiles"] = self.min_dynamic_tiles
output["max_dynamic_tiles"] = self.max_dynamic_tiles
output["tie_word_embeddings"] = self.tie_word_embeddings
output["_attn_implementation"] = self._attn_implementation
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
output["use_pixel_shuffle"] = self.use_pixel_shuffle
output["mlp_connector_layers"] = self.mlp_connector_layers
return output
@@ -0,0 +1,503 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import annotations
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
ImagesKwargs,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
make_flat_list_of_images,
validate_kwargs,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_v2_available,
)
from transformers.video_utils import VideoInput
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F # noqa: N812
from transformers.image_utils import pil_torch_interpolation_mapping
else:
from torchvision.transforms import functional as F # noqa: N812
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
"""Crop the given tensor image.
Args:
img (torch.Tensor): Image to be cropped. Format should be (C, H, W) or (H, W).
left (int): The left coordinate of the crop box.
top (int): The top coordinate of the crop box.
right (int): The right coordinate of the crop box.
bottom (int): The bottom coordinate of the crop box.
Returns:
torch.Tensor: Cropped image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
if img.ndim not in [2, 3]:
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
# Index from the end so that both (H, W) and (C, H, W) inputs are handled.
img_height = img.shape[-2]
img_width = img.shape[-1]
if top < 0 or left < 0 or bottom > img_height or right > img_width:
raise ValueError("Crop coordinates out of bounds")
if top >= bottom or left >= right:
raise ValueError("Invalid crop coordinates")
return img[..., top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
pad_during_tiling: bool | None
do_pad: bool | None
@add_start_docstrings(
"Constructs a fast Eagle 2.5 VL image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was deprecated in transformers; remove it.
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method. Not used for processing videos.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 448, "width": 448}
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
max_dynamic_tiles = 12
min_dynamic_tiles = 1
use_thumbnail = True
pad_during_tiling = False
valid_kwargs = Eagle25VLFastImageProcessorKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was deprecated in transformers; remove it.
"""
max_dynamic_tiles (`int`, *optional*):
The maximum number of dynamic tiles to use for processing high resolution images.
min_dynamic_tiles (`int`, *optional*):
The minimum number of dynamic tiles to use for processing high resolution images.
use_thumbnail (`bool`, *optional*):
Whether to use a thumbnail for processing high resolution images.
pad_during_tiling (`bool`, *optional*):
Whether to pad the image during tiling.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
# NOTE(YL): we will overload the preprocess method to add the image_flags
# def preprocess(
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
# ) -> BatchFeature:
# return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
expected_ndims: int = 3,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
expected_ndims (`int`, *optional*, defaults to 3):
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _resize_for_patching(
self,
image: torch.Tensor,
target_resolution: tuple,
interpolation: F.InterpolationMode,
input_data_format: ChannelDimension,
) -> torch.Tensor:
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image ("torch.Tensor"):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
"torch.Tensor": The resized image.
"""
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
return resized_image
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
"""
Earlier versions mainly considered the aspect ratio.
This version also takes the area ratio into account.
"""
best_factor = float("-inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
# Covering 60% of the original image area is enough, so the area ratio is capped at 0.6.
factor_based_on_area_n_ratio = min(
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
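The scoring rule in `find_closest_aspect_ratio` can be tried out without the processor class. The sketch below is a standalone approximation: `candidate_grids` and `pick_grid` are hypothetical names, and the extra sort key that orders grids with the same tile count is an assumption added here to make the result deterministic.

```python
def candidate_grids(min_tiles, max_tiles):
    """All (columns, rows) grids whose tile count lies in [min_tiles, max_tiles]."""
    grids = {
        (cols, rows)
        for cols in range(1, max_tiles + 1)
        for rows in range(1, max_tiles + 1)
        if min_tiles <= cols * rows <= max_tiles
    }
    # Fewest tiles first; the second key only makes ties reproducible (assumption).
    return sorted(grids, key=lambda g: (g[0] * g[1], g))


def pick_grid(width, height, tile_size, min_tiles=1, max_tiles=12):
    """Pick the grid with the best product of capped area coverage and shape match."""
    aspect_ratio = width / height
    area = width * height
    best_score, best_grid = float("-inf"), (1, 1)
    for cols, rows in candidate_grids(min_tiles, max_tiles):
        grid_ratio = cols / rows
        # Area coverage stops counting once it reaches 0.6.
        area_score = min(cols * rows * tile_size * tile_size / area, 0.6)
        # 1.0 when the grid shape matches the image shape, smaller otherwise.
        shape_score = min(grid_ratio / aspect_ratio, aspect_ratio / grid_ratio)
        score = area_score * shape_score
        if score > best_score:  # strict: the earlier (smaller) grid wins ties
            best_score, best_grid = score, (cols, rows)
    return best_grid
```

With a tile size of 448, a 448x448 image maps to a single tile and an 896x448 image maps to two columns and one row, because larger grids can only tie the capped score and never exceed it.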
def _pad_for_patching(
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
) -> torch.Tensor:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
return padded_image
def _get_image_patches(
self,
image: torch.Tensor,
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: F.InterpolationMode,
pad_during_tiling: bool,
) -> list[torch.Tensor]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# enumerate every (columns, rows) tile grid whose tile count lies in [min_num, max_num]
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
# calculate the target width and height
target_width = tile_size * target_aspect_ratio[0]
target_height = tile_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if pad_during_tiling:
resized_image = self._resize_for_patching(
image,
(target_height, target_width),
interpolation=interpolation,
input_data_format=ChannelDimension.FIRST,
)
padded_image = self._pad_for_patching(
resized_image,
(target_height, target_width),
input_data_format=ChannelDimension.FIRST,
)
image_used_to_split = padded_image
else:
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
processed_tiles = []
for i in range(blocks):
box = (
(i % (target_width // tile_size)) * tile_size,
(i // (target_width // tile_size)) * tile_size,
((i % (target_width // tile_size)) + 1) * tile_size,
((i // (target_width // tile_size)) + 1) * tile_size,
)
# split the image
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
processed_tiles.append(split_img)
assert len(processed_tiles) == blocks
if use_thumbnail and len(processed_tiles) != 1:
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
processed_tiles.append(thumbnail_img)
return processed_tiles
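The crop boxes built in the tiling loop follow a row-major walk over the chosen grid. A minimal sketch of that index arithmetic, with the hypothetical helper name `tile_boxes` and no tensor dependency:

```python
def tile_boxes(cols, rows, tile_size):
    """Return (left, top, right, bottom) crop boxes in row-major order."""
    boxes = []
    for i in range(cols * rows):
        col, row = i % cols, i // cols  # position of tile i inside the grid
        boxes.append(
            (col * tile_size, row * tile_size, (col + 1) * tile_size, (row + 1) * tile_size)
        )
    return boxes
```

Here `cols` plays the role of `target_width // tile_size` in the loop above, so tile `i` sits in column `i % cols` and row `i // cols`.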
def _pad_for_batching(
self,
pixel_values: list[torch.Tensor],
) -> list[torch.Tensor]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
Pixel values for each image; each tensor has shape (`num_patches`, `channels`, `height`, `width`).
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
for image in pixel_values
]
return pixel_values
def _preprocess(
self,
images: list[torch.Tensor],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: F.InterpolationMode | None,
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
do_pad: bool,
return_tensors: str | TensorType | None,
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
) -> BatchFeature:
processed_images = []
image_sizes = []
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
tile_size = crop_size.height
elif size and size.height:
tile_size = size.height
else:
tile_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
min_num=min_dynamic_tiles,
max_num=max_dynamic_tiles,
size=size_tuple,
tile_size=tile_size,
use_thumbnail=use_thumbnail,
interpolation=interpolation,
pad_during_tiling=pad_during_tiling,
)
# Group images by size for batched processing
processed_image_patches_grouped = {}
# Added for transformers >=4.53.0 compatibility
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches,
disable_grouping=disable_grouping,
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches,
do_rescale,
rescale_factor,
do_normalize,
image_mean,
image_std,
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(
processed_image_patches_grouped, grouped_image_patches_index
)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes},
tensor_type=return_tensors,
)
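# Illustrative sketch (not part of the change): the size and tile-size fallbacks at the top of
# `_preprocess`, isolated as a pure function. `Size` and `resolve_sizes` are hypothetical names
# introduced here; `Size` only mimics the three fields of `SizeDict` that the method reads.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Size:
    """Hypothetical stand-in for SizeDict; only the fields read by the sketch."""

    height: Optional[int] = None
    width: Optional[int] = None
    shortest_edge: Optional[int] = None


def resolve_sizes(size: Size, crop_size: Optional[Size] = None) -> Tuple[Tuple[int, int], int]:
    # Resize target: explicit (height, width) if both are set, else a square of shortest_edge.
    if size.height and size.width:
        size_tuple = (size.height, size.width)
    else:
        size_tuple = (size.shortest_edge, size.shortest_edge)

    # Tile size: crop height wins, then resize height, then shortest_edge.
    if crop_size is not None and crop_size.height:
        tile_size = crop_size.height
    elif size.height:
        tile_size = size.height
    else:
        tile_size = size.shortest_edge
    return size_tuple, tile_size


print(resolve_sizes(Size(height=448, width=448)))
print(resolve_sizes(Size(shortest_edge=384)))
```

# The crop height takes precedence over the resize height, so a configured center crop
# determines the tile edge that the tiling step works with.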
def preprocess(
self,
images: ImageInput,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
if images is not None:
images = self._prepare_image_like_inputs(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
if videos is not None:
videos = self._prepare_image_like_inputs(
images=videos,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
# Added for transformers >=4.53.0 compatibility
resample = kwargs.pop("resample", self.resample)
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, PILImageResampling | int)
else resample
)
# Filter kwargs to only include those accepted by _preprocess
valid_preprocess_kwargs = {
"do_resize",
"size",
"max_dynamic_tiles",
"min_dynamic_tiles",
"use_thumbnail",
"pad_during_tiling",
"interpolation",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"return_tensors",
"pad_size",
"disable_grouping",
}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
if images is not None:
return self._preprocess(images, **filtered_kwargs)
elif videos is not None:
return self._preprocess(videos, **filtered_kwargs)
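# Illustrative sketch (not part of the change): how `preprocess` fills missing kwargs from
# instance attributes and then keeps only the names `_preprocess` accepts. `FakeProcessor`,
# `valid_names` and `fill_and_filter` are hypothetical names used only for this example.

```python
class FakeProcessor:
    """Hypothetical stand-in: class attributes play the role of instance defaults."""

    do_resize = True
    size = {"height": 448, "width": 448}
    valid_names = ("do_resize", "size", "do_convert_rgb", "device")


def fill_and_filter(processor, user_kwargs, accepted):
    kwargs = dict(user_kwargs)
    # Missing kwargs fall back to the processor attribute, or None if there is none.
    for name in processor.valid_names:
        kwargs.setdefault(name, getattr(processor, name, None))
    # Drop everything the downstream function does not accept.
    return {k: v for k, v in kwargs.items() if k in accepted}


result = fill_and_filter(FakeProcessor(), {"do_resize": False}, accepted={"do_resize", "size"})
print(result)
```

# User-supplied values win over defaults because `setdefault` never overwrites an existing key.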
__all__ = ["Eagle25VLImageProcessorFast"]
@@ -0,0 +1,396 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import inspect
import torch
import torch.utils.checkpoint as cp
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers.utils import add_start_docstrings, logging
from .configuration_eagle2_5_vl import Eagle25VLConfig
logger = logging.get_logger(__name__)
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
EAGLE2_5_VL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`Eagle25VLConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
EAGLE2_5_VL_START_DOCSTRING,
)
class Eagle25VLPreTrainedModel(PreTrainedModel):
config_class = Eagle25VLConfig
base_model_prefix = "model"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_no_split_modules = [
"Qwen2DecoderLayer",
"LlamaDecoderLayer",
"Siglip2EncoderLayer",
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear | nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
config_class = Eagle25VLConfig
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
super().__init__(config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
if config.use_pixel_shuffle:
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
else:
self.num_image_token = int((image_size // patch_size) ** 2)
self.select_layer = config.select_layer
self.downsample_ratio = config.downsample_ratio
self.loss_version = config.loss_version
self.mlp_checkpoint = config.mlp_checkpoint
self.use_pixel_shuffle = config.use_pixel_shuffle
self.mlp_connector_layers = config.mlp_connector_layers
logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
if vision_model is not None:
self.vision_model = vision_model
else:
if config.vision_config.model_type == "siglip_vision_model":
config.vision_config._attn_implementation = "flash_attention_2"
self.vision_model = SiglipVisionModel(config.vision_config)
else:
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
if language_model is not None:
self.language_model = language_model
else:
if config.text_config.architectures[0] == "LlamaForCausalLM":
self.language_model = LlamaForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
raise NotImplementedError("Phi3 is not implemented.")
# self.language_model = Phi3ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
assert config.text_config._attn_implementation == "flash_attention_2", (
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
)
self.language_model = Qwen2ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
self.language_model = Qwen3ForCausalLM(config.text_config)
else:
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
if config.mlp_connector_layers == 2:
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size, llm_hidden_size),
)
else:
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
self.image_token_index = config.image_token_index
self.neftune_alpha = None
if config.use_backbone_lora:
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
self.use_llm_lora = config.use_llm_lora
if config.use_llm_lora:
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
self.check_forward_kwargs()
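# Illustrative sketch (not part of the change): the `num_image_token` arithmetic from `__init__`
# as a standalone function. The values 448, 14 and 0.5 are assumptions chosen for the example;
# the real values come from the model config.

```python
def num_image_tokens(image_size, patch_size, downsample_ratio, use_pixel_shuffle):
    # Number of ViT patches along one side of a square tile.
    patches_per_side = image_size // patch_size
    if use_pixel_shuffle:
        # Pixel shuffle trades spatial resolution for channels, so the token
        # count shrinks by downsample_ratio squared.
        return int(patches_per_side**2 * downsample_ratio**2)
    return int(patches_per_side**2)


print(num_image_tokens(448, 14, 0.5, use_pixel_shuffle=True))
print(num_image_tokens(448, 14, 0.5, use_pixel_shuffle=False))
```

# With these assumed values the pixel-shuffle path gives 256 tokens per tile, which matches
# the `tokens_per_tile=256` default of the processor.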
def check_forward_kwargs(self):
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
# has special handling for functions with **kwargs parameters that would affect
# how our model is processed during training and inference.
forward_params = inspect.signature(self.forward).parameters
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
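# Illustrative sketch (not part of the change): the same `inspect`-based check applied to two
# plain functions. `has_var_keyword`, `explicit` and `catch_all` are hypothetical names.

```python
import inspect


def has_var_keyword(fn) -> bool:
    """Return True if fn declares a **kwargs-style parameter."""
    params = inspect.signature(fn).parameters.values()
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)


def explicit(a, b=None):
    return a, b


def catch_all(a, **kwargs):
    return a, kwargs


print(has_var_keyword(explicit))
print(has_var_keyword(catch_all))
```

# `check_forward_kwargs` asserts that this check is False for `forward`, so adding `**kwargs`
# to `forward` later fails at construction time rather than changing behavior silently.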
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.out_proj",
"mlp.fc1",
"mlp.fc2",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
self.vision_model.print_trainable_parameters()
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
task_type="CAUSAL_LM",
)
self.language_model = get_peft_model(self.language_model, lora_config)
self.language_model.enable_input_require_grads()
self.language_model.print_trainable_parameters()
self.use_llm_lora = True
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
image_flags: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
num_tiles_list: list[torch.Tensor] | None = None,
) -> tuple | CausalLMOutputWithPast:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_embeds = self.language_model.get_input_embeddings()(input_ids)
vit_embeds = self.extract_feature(pixel_values)
if image_flags is not None:
image_flags = image_flags.view(-1)
vit_embeds = vit_embeds[image_flags == 1]
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.image_token_index
try:
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, c)
logger.warning(
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
f"vit_embeds.shape={vit_embeds.shape}"
)
n_token = selected.sum()
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
input_embeds = input_embeds.reshape(b, n, c)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
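# Illustrative sketch (not part of the change): the shifted next-token loss from `forward`,
# written with plain Python lists for a single sequence. `shifted_cross_entropy` is a
# hypothetical helper; it mirrors the defaults of `CrossEntropyLoss` (mean reduction,
# ignore_index=-100) but is not the PyTorch implementation.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def shifted_cross_entropy(logits, labels):
    """logits: one list of vocabulary scores per position; labels: one int per position."""
    # Position t predicts token t + 1: drop the last logit row and the first label.
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    losses = []
    for scores, target in zip(shift_logits, shift_labels):
        if target == IGNORE_INDEX:
            continue  # masked positions do not contribute to the mean
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))  # stable log-sum-exp
        losses.append(log_z - scores[target])
    # Assumes at least one unmasked target; PyTorch returns NaN when all are masked.
    return sum(losses) / len(losses)


uniform = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
print(shifted_cross_entropy(uniform, [0, 1, 0]))
```

# For uniform scores over two classes, every unmasked position contributes log(2).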
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
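# Illustrative sketch (not part of the change): the shape bookkeeping of `pixel_shuffle`
# without tensors. `pixel_shuffle_shape` is a hypothetical helper that follows the same
# view/permute sequence; the input shape below is an assumed example.

```python
def pixel_shuffle_shape(shape, scale_factor=0.5):
    n, w, h, c = shape
    # view:    (n, w, h * s, c / s)
    # permute: (n, h * s, w, c / s)
    # view:    (n, h * s, w * s, c / s**2)
    # permute: (n, w * s, h * s, c / s**2)
    return (
        n,
        int(w * scale_factor),
        int(h * scale_factor),
        int(c / (scale_factor * scale_factor)),
    )


before = (2, 32, 32, 1024)
after = pixel_shuffle_shape(before)
print(before, "->", after)
```

# The element count is unchanged: each 2x2 neighborhood of positions is folded into the
# channel dimension, so a 32x32 grid of 1024-dim features becomes 16x16 of 4096-dim features.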
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
)
if hasattr(vit_embeds, "last_hidden_state"):
vit_embeds = vit_embeds.last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
).hidden_states[self.select_layer]
if self.use_pixel_shuffle:
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
if self.mlp_checkpoint and vit_embeds.requires_grad:
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
else:
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor | None = None,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.LongTensor | None = None,
visual_features: torch.FloatTensor | None = None,
generation_config: GenerationConfig | None = None,
output_hidden_states: bool | None = None,
image_sizes: list[tuple[int, int]] | None = None,
**generate_kwargs,
) -> torch.LongTensor:
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
input_embeds = self.language_model.get_input_embeddings()(input_ids)
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.config.image_token_index
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
input_embeds = input_embeds.reshape(b, n, c)
else:
input_embeds = self.language_model.get_input_embeddings()(input_ids)
if "use_cache" not in generate_kwargs:
generate_kwargs["use_cache"] = True
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
**generate_kwargs,
)
return outputs
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
@@ -0,0 +1,541 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Eagle25VL.
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
"""
import base64
import os
import re
from io import BytesIO
import requests
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
logger = logging.get_logger(__name__)
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 256
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == "RGBA":
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
image = ele["image"] if "image" in ele else ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True, timeout=10)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
)
image = to_rgb(image_obj)
if "scale_factor" in ele:
scale_factor = ele["scale_factor"]
image = image.resize((int(image.width * scale_factor), int(image.height * scale_factor)), Image.BILINEAR)
return image
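# Illustrative sketch (not part of the change): the `data:image...;base64,` branch of
# `fetch_image`, reduced to the decoding step. `decode_data_uri` is a hypothetical helper;
# unlike `fetch_image`, it raises immediately when the URI has no base64 payload.

```python
import base64


def decode_data_uri(uri: str) -> bytes:
    if not uri.startswith("data:image") or "base64," not in uri:
        raise ValueError("expected a base64-encoded data:image URI")
    # Split once so that any later occurrence of the marker stays in the payload.
    _, payload = uri.split("base64,", 1)
    return base64.b64decode(payload)


raw = b"\x89PNG"
uri = "data:image/png;base64," + base64.b64encode(raw).decode("ascii")
print(decode_data_uri(uri) == raw)
```

# In `fetch_image` the decoded bytes are wrapped in `BytesIO` and passed to `Image.open`.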
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
# see processing_utils.ProcessingKwargs documentation for usage.
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {},
"videos_kwargs": {"max_dynamic_tiles": 1},
}
class Eagle25VLProcessor(ProcessorMixin):
r"""
Constructs an Eagle25VL processor which wraps an Eagle25VL video processor, Eagle25VL image processor and an Eagle25VL tokenizer into a single processor.
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
Args:
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
num_image_tokens (`int`, *optional*):
Number of image tokens for one image that will be returned by the vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Should be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<IMG_CONTEXT>"`):
Special token used to denote image location.
video_token (`str`, *optional*, defaults to `"<IMG_CONTEXT>"`):
Special token used to denote video location.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"num_image_tokens",
"vision_feature_select_strategy",
"image_token",
"video_token",
"images_kwargs",
"videos_kwargs",
"text_kwargs",
]
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
vision_feature_select_strategy=None,
chat_template=None,
image_token="<IMG_CONTEXT>", # nosec: B107
video_token="<IMG_CONTEXT>", # nosec: B107
tokens_per_tile=256,
image_placeholder="image",
video_placeholder="video",
image_start_token="<img>",
image_end_token="</img>",
**kwargs,
):
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None) is not None
else tokenizer.convert_tokens_to_ids(self.image_token)
)
self.video_token_id = (
tokenizer.video_token_id
if getattr(tokenizer, "video_token_id", None) is not None
else tokenizer.convert_tokens_to_ids(self.video_token)
)
self.image_placeholder = image_placeholder
self.video_placeholder = video_placeholder
self.tokens_per_tile = tokens_per_tile
self.image_start_token = image_start_token
self.image_end_token = image_end_token
if "auto_map" in kwargs:
self.auto_map = kwargs["auto_map"]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def replace_media_placeholder(
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
):
num_of_images_in_this_sample = 0
num_of_videos_in_this_sample = 0
# Regular expression pattern to match formats like <image-1> or <video-2>
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
unified_frame_list = []
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
# )
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
# )
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
# "use_thumbnail", self.image_processor.use_thumbnail
# )
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
)
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
)
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
"use_thumbnail", self.image_processor.use_thumbnail
)
tile_size = self.image_processor.size.get("height", 448)
# Function to replace tags in a single text
def replace_in_text(text):
# repl callback function for each match replacement operation
def repl(match):
nonlocal unified_frame_list
nonlocal num_of_images_in_this_sample
nonlocal num_of_videos_in_this_sample
media_type = match.group(1) # 'image' or 'video'
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
# Select the corresponding path based on media type
idx_mapper = {
0: "first",
1: "second",
2: "third",
3: "fourth",
4: "fifth",
5: "sixth",
6: "seventh",
7: "eighth",
8: "ninth",
9: "tenth",
}
if media_type == "image":
image_inputs = self.image_processor(
images=[image_list[idx_in_list]],
videos=None,
**output_kwargs["images_kwargs"],
)
if isinstance(image_inputs["pixel_values"], list):
_pv = image_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
image_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs)
num_of_images_in_this_sample += 1
elif media_type == "video":
video_inputs = self.image_processor(
images=None,
videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"],
)
if isinstance(video_inputs["pixel_values"], list):
_pv = video_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
video_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list:
frame_timestamps = timestamps_list[idx_in_list]
else:
frame_timestamps = None
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
num_of_tiles_each_frame = [
self.get_number_tiles_based_on_image_size(
image_size,
video_min_dynamic_tiles,
video_max_dynamic_tiles,
video_use_thumbnail,
tile_size,
)
for image_size in image_sizes
]
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
)
if frame_timestamps is not None:
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
)
special_placeholder = [
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
else:
special_placeholder = [
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
if sampled_fps is not None:
special_placeholder = (
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
+ "".join(special_placeholder)
)
else:
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
special_placeholder
)
unified_frame_list.append(video_inputs)
num_of_videos_in_this_sample += 1
else:
raise ValueError(f"Unknown media type: {media_type}")
return special_placeholder
return pattern.sub(repl, text)
text = replace_in_text(text)
if len(unified_frame_list) > 0:
def _to_tensor(v):
if isinstance(v, torch.Tensor):
return v
if isinstance(v, list):
if v and isinstance(v[0], list):
v = [t for sub in v for t in sub]
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
return torch.as_tensor(v)
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
else:
pixel_values = None
image_sizes = None
return (
text,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
)
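# Illustrative sketch (not part of the change): the placeholder expansion performed by
# `replace_media_placeholder`, restricted to images and without running an image processor.
# `expand_image_placeholders` and `tiles_per_image` are hypothetical; the tile counts that
# the real method reads from `pixel_values.shape[0]` are passed in directly here.

```python
import re

IMAGE_TOKEN = "<IMG_CONTEXT>"
IMAGE_START, IMAGE_END = "<img>", "</img>"


def expand_image_placeholders(text, tiles_per_image, tokens_per_tile=256):
    pattern = re.compile(r"<(image|video)-(\d+)>")

    def repl(match):
        media_type = match.group(1)
        idx = int(match.group(2)) - 1  # tags are 1-based, lists are 0-based
        if media_type != "image":
            return match.group(0)  # this sketch leaves video tags untouched
        n_tokens = tiles_per_image[idx] * tokens_per_tile
        return f"<image {idx + 1}>{IMAGE_START}{IMAGE_TOKEN * n_tokens}{IMAGE_END}"

    return pattern.sub(repl, text)


out = expand_image_placeholders("Describe <image-1>.", [2], tokens_per_tile=3)
print(out)
```

# Each tile contributes `tokens_per_tile` copies of the image token, so the number of
# image tokens in the prompt lines up with the number of visual embeddings the model inserts.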
def __call__(
self,
images: ImageInput = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
audio=None,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch tensor.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `text` contains at least one image or video placeholder; video frames are concatenated into the same tensor.
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Eagle25VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text_list = [text]
elif not isinstance(text, list) or not text or not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
elif isinstance(text, list) and isinstance(text[0], str):
text_list = text
if images is None:
images = []
if videos is None:
videos = []
pixel_values_list = []
image_sizes_list = []
new_sample_list = []
image_start_idx = 0
video_start_idx = 0
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
for sample in text_list:
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
(
sample,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
) = self.replace_media_placeholder(
sample,
images[image_start_idx:],
videos[video_start_idx:],
timestamps_list,
fps_list,
**output_kwargs,
)
new_sample_list.append(sample)
if pixel_values is not None:
pixel_values_list.append(pixel_values)
image_sizes_list.append(image_sizes)
image_start_idx += num_of_images_in_this_sample
video_start_idx += num_of_videos_in_this_sample
if len(pixel_values_list) > 0:
image_inputs = {
"pixel_values": torch.cat(pixel_values_list),
"image_sizes": torch.cat(image_sizes_list),
}
else:
image_inputs = {}
video_inputs = {}
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
def get_number_tiles_based_on_image_size(
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
) -> int:
"""
Get the number of tiles based on the image size.
"""
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# enumerate every candidate (i, j) tile grid with min_num <= i * j <= max_num
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# pick the candidate grid whose aspect ratio is closest to the image's aspect ratio
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and tiles_num > 1:
tiles_num += 1
return tiles_num
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# override to save video-config in a separate config file
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
outputs = super().save_pretrained(save_directory, **kwargs)
return outputs
# override to load video-config from a separate config file
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
if isinstance(processor, tuple):
processor = processor[0]
return processor
# Copied from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def process_vision_info(
self,
conversations: list[dict] | list[list[dict]],
return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
vision_infos = self.extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
video_timestamps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
else:
raise ValueError("image, image_url or video should be in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
if return_video_kwargs:
return (
image_inputs,
video_inputs,
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
)
return image_inputs, video_inputs
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele.get("type") in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
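# A minimal standalone sketch of the traversal that `extract_vision_info` performs,
# written with plain dicts so it runs without the processor class. The function name
# `extract_vision_entries` and the sample conversation are hypothetical; the sketch
# uses `dict.get` for the optional "type" key and skips non-list content.

```python
def extract_vision_entries(conversations):
    """Collect image/video content entries from one conversation or a batch of them."""
    # A single conversation is a list of message dicts; wrap it so both shapes are handled alike.
    if conversations and isinstance(conversations[0], dict):
        conversations = [conversations]
    media_keys = ("image", "image_url", "video")
    entries = []
    for conversation in conversations:
        for message in conversation:
            content = message.get("content")
            if not isinstance(content, list):
                continue  # plain-string content carries no media entries
            for element in content:
                # Keep an entry if it has a media key or declares a media type.
                if any(key in element for key in media_keys) or element.get("type") in media_keys:
                    entries.append(element)
    return entries


conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "frame_0.png"},
            {"type": "text", "text": "Pick up the cube."},
        ],
    },
    {"role": "assistant", "content": "Okay."},
]
found = extract_vision_entries(conversation)
print(found)
```

# Passing `[conversation]` (a batch of one) should return the same entries as passing
# `conversation` directly.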
__all__ = ["Eagle25VLProcessor"]
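# The tile-count logic in `get_number_tiles_based_on_image_size` can be sketched in pure
# Python as below. `closest_grid` is a hypothetical stand-in for
# `image_processor.find_closest_aspect_ratio`: it only minimises the aspect-ratio
# difference and does not reproduce any area-based tie-breaking the real method may
# apply. Treating a grid `(i, j)` as (columns, rows) is also an assumption.

```python
def candidate_grids(min_num, max_num):
    """All (cols, rows) grids whose tile count lies in [min_num, max_num]."""
    grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    # Sort by tile count first, then by the grid itself so the order is deterministic.
    return sorted(grids, key=lambda g: (g[0] * g[1], g))


def closest_grid(aspect_ratio, grids):
    """Hypothetical stand-in: grid whose cols/rows ratio is nearest to the image ratio."""
    return min(grids, key=lambda g: abs(aspect_ratio - g[0] / g[1]))


def count_tiles(image_size, min_num, max_num, use_thumbnail):
    height, width = image_size
    cols, rows = closest_grid(width / height, candidate_grids(min_num, max_num))
    tiles = cols * rows
    # A thumbnail tile is only added when the image is actually split.
    if use_thumbnail and tiles > 1:
        tiles += 1
    return tiles


wide = count_tiles((100, 200), min_num=1, max_num=4, use_thumbnail=True)
square = count_tiles((100, 100), min_num=1, max_num=4, use_thumbnail=True)
print(wide, square)
```

# In this sketch a 2:1 image maps to a 2x1 grid plus a thumbnail, while a square image
# stays a single tile because ties go to the grid with the fewest tiles.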
+380
@@ -0,0 +1,380 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
from .action_head.flow_matching_action_head import (
FlowmatchingActionHead,
FlowmatchingActionHeadConfig,
)
from .utils import ensure_eagle_cache_ready
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
class EagleBackbone(nn.Module):
def __init__(
self,
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
project_to_dim: int = 1536,
):
"""
Args:
tune_llm: whether to tune the LLM model (default: False)
tune_visual: whether to tune the visual model (default: False)
"""
super().__init__()
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
try:
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
if project_to_dim is not None:
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
else:
self.eagle_linear = torch.nn.Identity()
# needed since we don't use these layers. Also saves compute
while len(self.eagle_model.language_model.model.layers) > select_layer:
self.eagle_model.language_model.model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual)
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for p in self.parameters():
p.requires_grad = True
if not tune_llm:
self.eagle_model.language_model.requires_grad_(False)
if not tune_visual:
self.eagle_model.vision_model.requires_grad_(False)
self.eagle_model.mlp1.requires_grad_(False)
print(f"Tune backbone llm: {self.tune_llm}")
print(f"Tune backbone visual: {self.tune_visual}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_llm and not tune_visual:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Backbone trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No backbone trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if self.eagle_model.language_model and not self.tune_llm:
self.eagle_model.language_model.eval()
if self.eagle_model.vision_model and not self.tune_visual:
self.eagle_model.vision_model.eval()
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def forward_eagle(self, vl_input: BatchFeature) -> tuple[torch.Tensor, torch.Tensor]:
eagle_prefix = "eagle_"
eagle_input = {
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
}
del eagle_input["image_sizes"]
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
eagle_features = eagle_output.hidden_states[self.select_layer]
eagle_features = self.eagle_linear(eagle_features)
return eagle_features, eagle_input["attention_mask"]
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
if self.training and self.tune_visual:
dummy_term = torch.tensor(
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
)
for param in self.eagle_model.vision_model.parameters():
if param.requires_grad:
dummy_term = dummy_term + 0.0 * param.sum()
eagle_embeds = eagle_embeds + dummy_term
return BatchFeature(
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
) # [B, T2, hidden_size]
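# The key selection at the top of `forward_eagle` can be illustrated with plain dicts.
# This is a sketch, not the backbone itself: `select_prefixed` and the sample batch are
# hypothetical, and `pop(..., None)` is used here so a missing "image_sizes" entry does
# not raise, whereas the method above uses `del`.

```python
def select_prefixed(batch, prefix="eagle_"):
    """Keep only keys that start with `prefix`, with the prefix stripped."""
    return {key.removeprefix(prefix): value for key, value in batch.items() if key.startswith(prefix)}


batch = {
    "eagle_input_ids": [1, 2, 3],
    "eagle_attention_mask": [1, 1, 1],
    "eagle_image_sizes": [(224, 224)],
    "state": [0.1, 0.2],  # not prefixed, so the backbone never sees it
}
eagle_input = select_prefixed(batch)
eagle_input.pop("image_sizes", None)  # the model call does not accept this key
print(sorted(eagle_input))
```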
BACKBONE_FEATURE_KEY = "backbone_features"
ACTION_KEY = "action_pred"
LOSS_KEY = "loss"
ERROR_MSG = "Error: unexpected input/output"
N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
compute_dtype: str = "float32"
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
# real model
class GR00TN15(PreTrainedModel):
supports_gradient_checkpointing = True
config_class = GR00TN15Config
"""
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
here n is variable and can be e.g. time, 1 or user specified
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
"""
def __init__(
self,
config: GR00TN15Config,
local_model_path: str,
):
assert isinstance(config.backbone_cfg, dict)
assert isinstance(config.action_head_cfg, dict)
super().__init__(config)
self.local_model_path = local_model_path
self.backbone = EagleBackbone(**config.backbone_cfg)
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
self.action_head = FlowmatchingActionHead(action_head_cfg)
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
detected_error = False
error_msg = ERROR_MSG
if ACTION in inputs:
action = inputs[ACTION]
# In inference, action may be omitted or None; validate only when it's a tensor.
if action is None:
pass # allow None during inference
elif isinstance(action, torch.Tensor):
shape_ok = (
len(action.shape) == 3
and action.shape[1] == self.action_horizon
and action.shape[2] == self.action_dim
)
if not shape_ok:
error_msg += f"\n{action.shape=}"
detected_error = True
else:
# Unexpected non-tensor type provided for action
error_msg += f"\nInvalid type for action: {type(action)}"
detected_error = True
if "video" in inputs:
video = inputs["video"]
type_ok = isinstance(video, np.ndarray)
# Only inspect dtype and shape once `video` is known to be an ndarray.
dtype_ok = type_ok and video.dtype == np.uint8
shape_ok = type_ok and len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
if not type_ok:
error_msg += f"\n{type(video)=}"
detected_error = True
if not dtype_ok:
error_msg += f"\n{getattr(video, 'dtype', None)=}"
detected_error = True
if not shape_ok:
error_msg += f"\n{getattr(video, 'shape', None)=}"
detected_error = True
if detected_error:
raise ValueError(error_msg)
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
fail_backbone = (
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
)
if fail_backbone:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}" if BACKBONE_FEATURE_KEY in backbone_outputs else ""
raise ValueError(error_msg)
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
(
LOSS_KEY in action_head_outputs and is_training
) # there might not be an action prediction during training
or (
ACTION_KEY in action_head_outputs
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
)
)
if fail_action_head:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}" if ACTION_KEY in action_head_outputs else ""
error_msg += f"\n{self.action_horizon=}"
error_msg += f"\n{self.action_dim=}"
raise ValueError(error_msg)
def forward(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
return action_head_outputs
def get_action(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
return action_head_outputs
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
self.validate_inputs(inputs)
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_maybe_dtype(x):
# Cast floating tensors to a memory-efficient compute dtype when requested.
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
if getattr(self, "compute_dtype", None) == "bfloat16":
return x.to(self.device, dtype=torch.bfloat16)
# Fallback: preserve previous behavior if not using bf16 compute
return x.to(self.device, dtype=self.action_head.dtype)
# Non-floating tensors: move device only
return x.to(self.device)
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
return backbone_inputs, action_inputs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
print(f"Tune backbone vision tower: {tune_visual}")
print(f"Tune backbone LLM: {tune_llm}")
print(f"Tune action head projector: {tune_projector}")
print(f"Tune action head DiT: {tune_diffusion_model}")
# get the current model path being downloaded
try:
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
# saved in ~/.cache/huggingface/hub/
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
# HFValidationError, RepositoryNotFoundError
except (HFValidationError, RepositoryNotFoundError):
print(
f"Model not found or not available on the Hugging Face Hub. Loading from local path: {pretrained_model_name_or_path}"
)
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path, local_model_path=local_model_path, **kwargs
)
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
)
return pretrained_model
-933
@@ -1,933 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib
import logging
from contextlib import suppress
from copy import deepcopy
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
if TYPE_CHECKING or _transformers_available:
from transformers import (
AutoConfig,
AutoModel,
PretrainedConfig,
PreTrainedModel,
Qwen3VLConfig,
Qwen3VLForConditionalGeneration,
)
from transformers.feature_extraction_utils import BatchFeature
else:
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
Qwen3VLConfig = None
Qwen3VLForConditionalGeneration = None
try:
import tree
except ImportError:
tree = None
logger = logging.getLogger(__name__)
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"model_dtype": "bfloat16",
"dtype": "bfloat16",
"model_name": "nvidia/Cosmos-Reason2-2B",
"backbone_model_type": "qwen",
"model_revision": None,
"tune_top_llm_layers": 0,
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 16,
"reproject_vision": False,
"use_flash_attention": False,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
"shortest_image_edge": None,
"crop_fraction": None,
"random_rotation_angle": None,
"color_jitter_params": None,
"use_albumentations_transforms": True,
"extra_augmentation_config": None,
"formalize_language": True,
"apply_sincos_state_encoding": False,
"use_percentiles": True,
"use_relative_action": False,
"max_state_dim": 132,
"max_action_dim": 132,
"action_horizon": 40,
"hidden_size": 1024,
"input_embedding_dim": 1536,
"state_history_length": 1,
"add_pos_embed": True,
"attn_dropout": 0.2,
"use_vlln": True,
"max_seq_len": 1024,
"use_alternate_vl_dit": True,
"attend_text_every_n_blocks": 2,
"diffusion_model_cfg": {
"positional_embeddings": None,
"num_layers": 32,
"num_attention_heads": 32,
"attention_head_dim": 48,
"norm_type": "ada_norm",
"dropout": 0.2,
"final_dropout": True,
"output_dim": 1024,
"interleave_self_attention": True,
},
"vl_self_attention_cfg": {
"positional_embeddings": None,
"num_layers": 4,
"num_attention_heads": 32,
"attention_head_dim": 64,
"dropout": 0.2,
"final_dropout": True,
},
"num_inference_timesteps": 4,
"noise_beta_alpha": 1.5,
"noise_beta_beta": 1.0,
"noise_s": 0.999,
"num_timestep_buckets": 1000,
"tune_projector": True,
"tune_diffusion_model": True,
"tune_vlln": True,
"state_dropout_prob": 0.2,
"exclude_state": False,
"use_mean_std": False,
"max_num_embodiments": 32,
"rtc_ramp_rate": 6.0,
}
class GR00TN17Config(PretrainedConfig):
"""Configuration for NVIDIA GR00T N1.7.
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
flow-matching action head. This mirrors the public N1.7 checkpoint config
while keeping it local to LeRobot and independent from the external
Isaac-GR00T ``gr00t`` Python package.
"""
model_type = "Gr00tN1d7"
_defaults = GR00T_N1_7_DEFAULTS
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in GR00T_N1_7_DEFAULTS.items():
setattr(self, key, deepcopy(kwargs.pop(key, value)))
for key, value in kwargs.items():
setattr(self, key, value)
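# A small sketch of the defaults-plus-overrides pattern used in `GR00TN17Config.__init__`,
# without the `PretrainedConfig` base class. `SketchConfig` and the trimmed `DEFAULTS`
# dict are hypothetical. It shows why the defaults are deep-copied: nested dicts such as
# `diffusion_model_cfg` would otherwise be shared between instances.

```python
from copy import deepcopy

DEFAULTS = {
    "action_horizon": 40,
    "diffusion_model_cfg": {"num_layers": 32, "dropout": 0.2},
}


class SketchConfig:
    def __init__(self, **kwargs):
        # Known keys: take the override if given, else the default; copy either way.
        for key, value in DEFAULTS.items():
            setattr(self, key, deepcopy(kwargs.pop(key, value)))
        # Anything left over is an extra attribute supplied by the caller.
        for key, value in kwargs.items():
            setattr(self, key, value)


first = SketchConfig(action_horizon=16, custom_flag=True)
second = SketchConfig()
first.diffusion_model_cfg["num_layers"] = 4  # mutate one instance only
print(second.diffusion_model_cfg["num_layers"], DEFAULTS["diffusion_model_cfg"]["num_layers"])
```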
class CategorySpecificLinear(nn.Module):
"""Linear layer with category-specific weights for multi-embodiment support."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
super().__init__()
self.num_categories = num_categories
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
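# What `CategorySpecificLinear.forward` computes can be written out with nested lists:
# each sample picks the weight matrix and bias of its own category. This is a pure-Python
# sketch for illustration only; `category_linear`, `matvec_rows`, and the toy weights are
# hypothetical, and the real layer does the same work in one batched `torch.bmm` call.

```python
def matvec_rows(rows, weight):
    """Multiply each row vector (length D_in) by `weight` (D_in x D_out)."""
    d_out = len(weight[0])
    return [
        [sum(row[i] * weight[i][j] for i in range(len(row))) for j in range(d_out)]
        for row in rows
    ]


def category_linear(x, cat_ids, weights, biases):
    """x: (B, T, D_in) nested lists; cat_ids: length-B list of category indices."""
    out = []
    for sample, cat in zip(x, cat_ids):
        projected = matvec_rows(sample, weights[cat])  # (T, D_out) for this sample
        out.append([[v + b for v, b in zip(row, biases[cat])] for row in projected])
    return out


weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # category 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # category 1: doubling
]
biases = [[0.0, 0.0], [1.0, 1.0]]
x = [[[1.0, 2.0]], [[1.0, 2.0]]]  # B=2, T=1, D_in=2, same input for both samples
y = category_linear(x, [0, 1], weights, biases)
print(y)
```

# The two samples share an input but land on different outputs because their category
# ids select different parameters.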
class CategorySpecificMLP(nn.Module):
"""Two-layer MLP with category-specific weights."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class SinusoidalPositionalEncoding(nn.Module):
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
The frequency scalar is intentionally created on CPU and then broadcast with
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
embedding and avoids tiny dtype/device construction differences in parity
tests.
"""
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
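# The encoding above, for a single scalar timestep, reduces to the following formula:
# frequency k is exp(-k * ln(10000) / half_dim), and the output concatenates all sines
# followed by all cosines. This scalar version is a sketch using only `math`;
# `sinusoidal_encoding` is a hypothetical helper name.

```python
import math


def sinusoidal_encoding(timestep, embedding_dim):
    """Return [sin(t * f_0), ..., sin(t * f_{h-1}), cos(t * f_0), ..., cos(t * f_{h-1})]."""
    half_dim = embedding_dim // 2
    scale = math.log(10000.0) / half_dim
    freqs = [timestep * math.exp(-k * scale) for k in range(half_dim)]
    return [math.sin(f) for f in freqs] + [math.cos(f) for f in freqs]


at_zero = sinusoidal_encoding(0.0, 4)
at_one = sinusoidal_encoding(1.0, 8)
print(at_zero, len(at_one))
```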
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class MultiEmbodimentActionEncoder(nn.Module):
"""Action encoder with category-specific projections and sinusoidal time encoding."""
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
super().__init__()
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
batch_size, horizon, _ = actions.shape
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
raise ValueError("Expected `timesteps` to have shape (B,).")
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
action_emb = self.W1(actions, cat_ids)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
return self.W3(x, cat_ids)
class Qwen3Backbone(nn.Module):
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
The public checkpoint stores the action head in the GR00T checkpoint but
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
keeps the nested HF module layout compatible across transformer versions
and exposes the hidden states consumed by the action head.
"""
def __init__(
self,
model_name: str = "nvidia/Cosmos-Reason2-2B",
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
tune_top_llm_layers: int = 0,
trainable_params_fp32: bool = False,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_pretrained_weights: bool = True,
):
require_package("transformers", extra="groot")
super().__init__()
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
extra_kwargs: dict[str, Any] = {}
if use_flash_attention:
try:
import flash_attn # noqa: F401
extra_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
extra_kwargs["attn_implementation"] = "sdpa"
if load_bf16:
extra_kwargs["torch_dtype"] = torch.bfloat16
if load_pretrained_weights:
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
model_name,
**extra_kwargs,
**transformers_loading_kwargs,
).eval()
else:
self.model = self._from_backbone_config(
model_name=model_name,
model_kwargs=extra_kwargs,
config_kwargs=transformers_loading_kwargs,
).eval()
while len(self.language_model.layers) > select_layer:
self.language_model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
if load_bf16 and trainable_params_fp32:
for parameter in self.parameters():
if parameter.requires_grad:
parameter.data = parameter.data.to(torch.float32)
def set_trainable_parameters(
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
) -> None:
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_llm:
self.language_model.requires_grad_(False)
if not tune_visual:
self.visual.requires_grad_(False)
if tune_top_llm_layers > 0:
for layer in self.language_model.layers[-tune_top_llm_layers:]:
for parameter in layer.parameters():
parameter.requires_grad = True
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if self.language_model and not self.tune_llm:
self.language_model.eval()
if self.visual and not self.tune_visual:
self.visual.eval()
@property
def language_model(self) -> nn.Module:
return getattr(self.model, "model", self.model).language_model
@property
def visual(self) -> nn.Module:
return getattr(self.model, "model", self.model).visual
def _from_backbone_config(
self,
*,
model_name: str,
model_kwargs: dict[str, Any],
config_kwargs: dict[str, Any],
) -> nn.Module:
if _is_cosmos_reason2_backbone(model_name):
backbone_config = _cosmos_reason2_qwen3_vl_config()
else:
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
if "mm_token_type_ids" in model_input:
return
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
return
input_ids = model_input.get("input_ids")
if input_ids is None:
return
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None:
mm_token_type_ids[input_ids == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[input_ids == video_token_id] = 2
model_input["mm_token_type_ids"] = mm_token_type_ids
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
drops text position ids before calling text-layer flash attention. GR00T
N1.7 was aligned against the older Transformers path, where a fourth text
position row is forwarded alongside the temporal/height/width rows. Adding
the row here preserves the newer multimodal position computation while
keeping flash attention on the legacy code path.
"""
if "position_ids" in model_input:
return
qwen3_model = getattr(self.model, "model", self.model)
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
if compute_3d_position_ids is None:
return
position_ids = compute_3d_position_ids(
input_ids=model_input.get("input_ids"),
image_grid_thw=model_input.get("image_grid_thw"),
video_grid_thw=model_input.get("video_grid_thw"),
inputs_embeds=None,
attention_mask=model_input.get("attention_mask"),
past_key_values=None,
mm_token_type_ids=model_input.get("mm_token_type_ids"),
)
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
model_input["position_ids"] = position_ids
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
Newer releases expose the post-final-norm tensor there instead. Capturing
the last decoder layer output directly keeps the N1.7 action head input
stable across Transformers versions.
"""
captured: dict[str, torch.Tensor] = {}
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
if isinstance(output, torch.Tensor):
captured["features"] = output
elif isinstance(output, (tuple, list)) and output:
captured["features"] = output[0]
elif hasattr(output, "last_hidden_state"):
captured["features"] = output.last_hidden_state
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
try:
outputs = self.model(**model_input, output_hidden_states=True)
finally:
hook.remove()
return captured.get("features", outputs.hidden_states[-1])
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
model_input = {key: vl_input[key] for key in keys_to_use}
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
self._ensure_mm_token_type_ids(model_input)
self._ensure_legacy_qwen3_position_ids(model_input)
features = self._last_decoder_layer_output(model_input)
image_mask = model_input["input_ids"] == self.model.config.image_token_id
attention_mask = model_input["attention_mask"] == 1
return BatchFeature(
data={
"backbone_features": features,
"backbone_attention_mask": attention_mask,
"image_mask": image_mask,
}
)
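# The fallback in `_ensure_mm_token_type_ids` labels each token as text (0), image (1),
# or video (2) based on its id. A list-based sketch of that labelling follows; the
# function name and the token ids 900 and 901 are made up for illustration and are not
# the real Qwen3-VL ids.

```python
def build_mm_token_type_ids(input_ids, image_token_id=None, video_token_id=None):
    """Map every token id to 0 (text), 1 (image placeholder) or 2 (video placeholder)."""
    type_ids = []
    for row in input_ids:
        labels = []
        for token in row:
            if video_token_id is not None and token == video_token_id:
                labels.append(2)
            elif image_token_id is not None and token == image_token_id:
                labels.append(1)
            else:
                labels.append(0)
        type_ids.append(labels)
    return type_ids


input_ids = [[5, 900, 900, 7], [901, 901, 8, 9]]
type_ids = build_mm_token_type_ids(input_ids, image_token_id=900, video_token_id=901)
print(type_ids)
```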
class GR00TN17ActionHead(nn.Module):
supports_gradient_checkpointing = True
def __init__(self, config: GR00TN17Config):
require_package("diffusers", extra="groot")
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
if config.use_alternate_vl_dit:
self.model = AlternateVLDiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
)
else:
self.model = DiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
)
self.action_dim = config.max_action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim * config.state_history_length,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
else:
self.vl_self_attention = nn.Identity()
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.state_dropout_prob = config.state_dropout_prob
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
def set_trainable_parameters(
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
) -> None:
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
self.tune_vlln = tune_vlln
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
if not tune_vlln:
self.vlln.requires_grad_(False)
self.vl_self_attention.requires_grad_(False)
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
if not self.tune_vlln:
self.vlln.eval()
self.vl_self_attention.eval()
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
if self._beta_dist is None:
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (1 - sample) * self.config.noise_s
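The time sampling in `sample_time` can be sketched without PyTorch. This is a minimal illustration, not the implementation above: it draws `s ~ Beta(alpha, beta)` with the standard library and returns `(1 - s) * noise_s`, so every sample lies in `[0, noise_s]`. The values of `alpha`, `beta`, `noise_s`, and `num_timestep_buckets` below are assumptions chosen for illustration and are not read from the config in this diff.

```python
import random


def sample_time(batch_size, alpha=1.5, beta=1.0, noise_s=0.999, seed=0):
    """Draw flow-matching times t = (1 - s) * noise_s with s ~ Beta(alpha, beta).

    alpha, beta and noise_s are placeholder values chosen for illustration.
    """
    rng = random.Random(seed)
    return [(1.0 - rng.betavariate(alpha, beta)) * noise_s for _ in range(batch_size)]


def discretize(t, num_timestep_buckets=1000):
    """Map a continuous time to an integer bucket by truncation."""
    return int(t * num_timestep_buckets)


times = sample_time(8)
buckets = [discretize(t) for t in times]
```

Because `noise_s` is below 1 in this sketch, the truncated bucket index always stays below `num_timestep_buckets`.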
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = self.vlln(backbone_output["backbone_features"])
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
vl_embeds = backbone_output.backbone_features
device = vl_embeds.device
embodiment_id = action_input.embodiment_id
if action_input.state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = action_input.state.view(action_input.state.shape[0], 1, -1)
state_features = self.state_encoder(state, embodiment_id)
if self.training and self.state_dropout_prob > 0:
do_dropout = (
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
)
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None]
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
return BatchFeature(
data={
"loss": loss,
"action_loss": action_loss,
"action_mask": action_mask,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
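The training objective in the action head's `forward` is a linear-interpolation flow-matching loss: the noisy sample is `(1 - t) * noise + t * actions`, the regression target is the velocity `actions - noise`, and the squared error is averaged over unmasked entries only. The sketch below restates that arithmetic on plain Python lists for a single trajectory; the helper names are hypothetical and the numbers are made up.

```python
def flow_matching_pair(action, noise, t):
    """Return the noisy sample x_t and the velocity target for one trajectory."""
    noisy = [(1.0 - t) * n + t * a for a, n in zip(action, noise)]
    velocity = [a - n for a, n in zip(action, noise)]
    return noisy, velocity


def masked_mse(pred, target, mask, eps=1e-6):
    """Sum of masked squared errors divided by the number of unmasked entries."""
    total = sum(m * (p - q) ** 2 for p, q, m in zip(pred, target, mask))
    return total / (sum(mask) + eps)


action = [0.5, -1.0, 2.0]
noise = [0.1, 0.3, -0.4]
t = 0.25
noisy, velocity = flow_matching_pair(action, noise, t)

# Following the true velocity for the remaining time (1 - t) recovers the action.
recovered = [x + (1.0 - t) * v for x, v in zip(noisy, velocity)]

# A perfect velocity prediction gives zero loss.
perfect_loss = masked_mse(velocity, velocity, mask=[1.0, 1.0, 0.0])
# An error confined to a masked-out (padded) dimension does not change the loss.
padded_error_loss = masked_mse([velocity[0], velocity[1], 99.0], velocity, mask=[1.0, 1.0, 0.0])
```

The mask is what lets embodiments with fewer action dimensions share a padded `max_action_dim` tensor without the padding contributing to the gradient.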
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
state = action_input.state
if state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = state.view(state.shape[0], 1, -1)
state_features = self.state_encoder(state, action_input.embodiment_id)
return BatchFeature(
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
)
@torch.no_grad()
def get_action_with_features(
self,
backbone_features: torch.Tensor,
state_features: torch.Tensor,
embodiment_id: torch.Tensor,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
vl_embeds = backbone_features
batch_size = vl_embeds.shape[0]
device = vl_embeds.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.action_dim),
dtype=vl_embeds.dtype,
device=device,
)
dt = 1.0 / self.num_inference_timesteps
vel_strength = torch.ones_like(actions)
if "action" in action_input:
if options is None:
raise ValueError("RTC options are required when action is provided to get_action.")
action_horizon_before_padding = options["action_horizon"]
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
:,
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
:,
]
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
ramp = ramp / ramp[-1].clamp_min(1e-8)
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
None, :, None
].to(device)
for t_step in range(self.num_inference_timesteps):
t_cont = t_step / float(self.num_inference_timesteps)
t_discretized = int(t_cont * self.num_timestep_buckets)
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
return BatchFeature(
data={
"action_pred": actions,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
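The sampling loop above is a fixed-step Euler integration of the predicted velocity, with a per-step multiplier (`vel_strength`) used for real-time chunking (RTC). A rough sketch of that multiplier and of the integration, on plain lists and with hypothetical helper names, might look as follows. It covers only the multiplier and the update rule; copying the previous chunk's actions into the overlap prefix is omitted, and the constant velocity field is an assumption made for the demo.

```python
import math


def rtc_velocity_strength(horizon, overlap_steps, frozen_steps, ramp_rate):
    """Per-step multiplier for the velocity update.

    Frozen steps get 0, steps between frozen_steps and overlap_steps follow
    a normalized exponential ramp, and all remaining steps get 1.
    """
    strength = [1.0] * horizon
    for i in range(frozen_steps):
        strength[i] = 0.0
    intermediate = overlap_steps - frozen_steps
    # intermediate + 2 evenly spaced points on [0, 1]; the endpoints are dropped below.
    ts = [k / (intermediate + 1) for k in range(intermediate + 2)]
    ramp = [1.0 - math.exp(-ramp_rate * t) for t in ts]
    ramp = [r / max(ramp[-1], 1e-8) for r in ramp]
    for offset, value in enumerate(ramp[1:-1]):
        strength[frozen_steps + offset] = value
    return strength


def euler_denoise(x, velocity_fn, strength, num_steps):
    """Integrate x <- x + dt * v(x, t) * strength from t = 0 with a fixed step."""
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step / num_steps
        v = velocity_fn(x, t)
        x = [xi + dt * vi * si for xi, vi, si in zip(x, v, strength)]
    return x


strength = rtc_velocity_strength(horizon=6, overlap_steps=4, frozen_steps=2, ramp_rate=6.0)
start = [0.0] * 6
target = [1.0] * 6
# Hypothetical velocity field: constant velocity pointing from start to target.
result = euler_denoise(start, lambda x, t: [b - a for a, b in zip(start, target)], strength, num_steps=4)
```

With a multiplier of 0 the frozen steps never move, so they keep whatever value they were initialized with, while steps past the overlap are denoised at full strength.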
@torch.no_grad()
def get_action(
self,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
features = self._encode_features(backbone_output, action_input)
return self.get_action_with_features(
backbone_features=features.backbone_features,
state_features=features.state_features,
embodiment_id=action_input.embodiment_id,
backbone_output=backbone_output,
action_input=action_input,
options=options,
)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
return Qwen3VLConfig(
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=True,
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"tie_word_embeddings": True,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [5, 11, 17],
"depth": 24,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1024,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 2048,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
)
def get_backbone_cls(config: GR00TN17Config):
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
return Qwen3Backbone
if config.backbone_model_type == "qwen":
logger.warning(
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
"backbone because backbone_model_type='qwen'.",
config.model_name,
)
return Qwen3Backbone
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
class GR00TN17(PreTrainedModel):
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
config_class = GR00TN17Config
supports_gradient_checkpointing = True
def __init__(
self,
config: GR00TN17Config,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_backbone_weights: bool = True,
):
_register_with_transformers()
super().__init__(config)
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
self.config = config
backbone_cls = get_backbone_cls(config)
self.backbone = backbone_cls(
model_name=config.model_name,
tune_llm=config.tune_llm,
tune_visual=config.tune_visual,
select_layer=config.select_layer,
reproject_vision=config.reproject_vision,
use_flash_attention=config.use_flash_attention,
load_bf16=config.load_bf16,
tune_top_llm_layers=config.tune_top_llm_layers,
trainable_params_fp32=config.backbone_trainable_params_fp32,
transformers_loading_kwargs=transformers_loading_kwargs,
load_pretrained_weights=load_backbone_weights,
)
self.action_head = GR00TN17ActionHead(config)
self.post_init()
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
global tree
if tree is None:
require_package("dm-tree", extra="groot", import_name="tree")
tree = importlib.import_module("tree")
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_dtype(x):
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
return x.to(self.device, dtype=self.dtype)
return x.to(self.device)
return (
tree.map_structure(to_device_with_dtype, backbone_inputs),
tree.map_structure(to_device_with_dtype, action_inputs),
)
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head(backbone_outputs, action_inputs)
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head.get_action(backbone_outputs, action_inputs, options)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
tune_vlln = kwargs.pop("tune_vlln", True)
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
for key in ("cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
try:
local_model_path = snapshot_download(
pretrained_model_name_or_path,
repo_type="model",
revision=kwargs.get("revision"),
cache_dir=kwargs.get("cache_dir"),
local_files_only=kwargs.get("local_files_only", False),
token=kwargs.get("token"),
)
except (HFValidationError, RepositoryNotFoundError):
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path,
transformers_loading_kwargs=transformers_loading_kwargs,
load_backbone_weights=load_backbone_weights,
**kwargs,
)
pretrained_model.backbone.set_trainable_parameters(
tune_visual=tune_visual,
tune_llm=tune_llm,
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector,
tune_diffusion_model=tune_diffusion_model,
tune_vlln=tune_vlln,
)
return pretrained_model
def _register_with_transformers() -> None:
"""Register GR00T N1.7 with transformers' Auto* factories.
Idempotent: ``register(..., exist_ok=True)`` makes repeat calls no-ops (with a fallback that
suppresses the already-registered error on transformers builds whose ``register()`` predates
``exist_ok``), so no run-once guard is needed.
"""
if AutoConfig is None or AutoModel is None:
return
try:
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
try:
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoModel.register(GR00TN17Config, GR00TN17)
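The `exist_ok` fallback used in `_register_with_transformers` can be illustrated in isolation. The sketch below uses a hypothetical `LegacyRegistry` class standing in for a registry whose `register()` has no `exist_ok` parameter; it is not the transformers API, only a stand-in to show why the `TypeError` / `suppress(ValueError)` pair makes repeat calls harmless.

```python
from contextlib import suppress


class LegacyRegistry:
    """Hypothetical registry whose register() predates an exist_ok flag."""

    def __init__(self):
        self.entries = {}

    def register(self, key, value):
        if key in self.entries:
            raise ValueError(f"'{key}' is already registered")
        self.entries[key] = value


def register_once(registry, key, value):
    """Register key -> value, tolerating repeat calls on old and new registries."""
    try:
        registry.register(key, value, exist_ok=True)
    except TypeError:
        # Older signature without exist_ok: fall back and ignore duplicates.
        with suppress(ValueError):
            registry.register(key, value)


registry = LegacyRegistry()
register_once(registry, "gr00t_n1_7", "config-class")
register_once(registry, "gr00t_n1_7", "config-class")  # second call is a no-op
```

One caveat of this pattern is that a `TypeError` raised for an unrelated reason inside `register()` would also take the fallback path.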
+104 -237
@@ -17,22 +17,28 @@
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code. Dataset loading and training
orchestration are handled by LeRobot's standard training stack.
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Notes:
- Dataset loading and full training orchestration are handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
"""
import builtins
import logging
import os
from collections import deque
from pathlib import Path
from typing import TypeVar
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
@@ -40,19 +46,8 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_groot import (
GROOT_N1_5,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
)
from .groot_n1_7 import GR00TN17
logger = logging.getLogger(__name__)
from .configuration_groot import GrootConfig
from .groot_n1 import GR00TN15
T = TypeVar("T", bound="GrootPolicy")
@@ -72,38 +67,37 @@ class GrootPolicy(PreTrainedPolicy):
# Initialize GR00T model using ported components
self._groot_model = self._create_groot_model()
self._action_queue_steps = self._resolve_action_queue_steps()
self.reset()
def _create_groot_model(self):
"""Create and initialize the GR00T N1.7 model using the ported components."""
model_kwargs = {
"pretrained_model_name_or_path": self.config.base_model_path,
"tune_llm": self.config.tune_llm,
"tune_visual": self.config.tune_visual,
"tune_projector": self.config.tune_projector,
"tune_diffusion_model": self.config.tune_diffusion_model,
# Forwarded as a GR00TN17Config override; read back by set_trainable_parameters.
"tune_top_llm_layers": self.config.tune_top_llm_layers,
"use_flash_attention": self.config.use_flash_attention,
}
# Surface the inference-time knobs onto the model config only when the user set them; None
# leaves the value baked into the checkpoint untouched.
if self.config.num_inference_timesteps is not None:
model_kwargs["num_inference_timesteps"] = self.config.num_inference_timesteps
if self.config.rtc_ramp_rate is not None:
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
"""Create and initialize the GR00T model using Isaac-GR00T API.
return GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=self.config.tune_vlln,
transformers_loading_kwargs={"trust_remote_code": True},
This is only called when creating a NEW policy (not when loading from checkpoint).
Steps (delegating to Isaac-GR00T):
1) Download and load pretrained model via GR00TN15.from_pretrained
2) Align action horizon with data_config if provided
"""
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
model = GR00TN15.from_pretrained(
pretrained_model_name_or_path=self.config.base_model_path,
tune_llm=self.config.tune_llm,
tune_visual=self.config.tune_visual,
tune_projector=self.config.tune_projector,
tune_diffusion_model=self.config.tune_diffusion_model,
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
return model
def reset(self):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self._action_queue_steps)
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(
@@ -124,7 +118,7 @@ class GrootPolicy(PreTrainedPolicy):
"""Load Groot policy from pretrained model.
Handles two cases:
1. Base GR00T N1.7 models - loads the raw model
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
Args:
@@ -143,11 +137,13 @@ class GrootPolicy(PreTrainedPolicy):
Returns:
Initialized GrootPolicy instance with loaded model
"""
requested_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
logger.info(
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
requested_version,
pretrained_name_or_path,
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
print(
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
)
model_id = str(pretrained_name_or_path)
@@ -178,7 +174,7 @@ class GrootPolicy(PreTrainedPolicy):
if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
@@ -194,13 +190,11 @@ class GrootPolicy(PreTrainedPolicy):
)
# This is a base GR00T model - load it fresh
logger.info("Detected base GR00T model, loading from HuggingFace...")
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
# Create default config with the pretrained path
config = GrootConfig(
base_model_path=str(pretrained_name_or_path),
)
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
# Add minimal visual feature required for validation
# validate_features() will automatically add state and action features
@@ -221,15 +215,6 @@ class GrootPolicy(PreTrainedPolicy):
if hasattr(config, key):
setattr(config, key, value)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != GROOT_N1_7:
message = (
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -240,160 +225,21 @@ class GrootPolicy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _resolve_action_queue_steps(self) -> int:
n_action_steps = int(self.config.n_action_steps)
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
execution_horizon = infer_groot_n1_7_action_execution_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
horizons = [n_action_steps]
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
return min(horizons)
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
horizons = [actions.shape[1]]
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
for horizon in (self.config.chunk_size, self.config.n_action_steps):
horizon = int(horizon)
if horizon > 0:
horizons.append(horizon)
return max(1, min(horizons))
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
allowed_base = {"state", "state_mask", "embodiment_id"}
if include_action:
allowed_base.update({"action", "action_mask"})
allowed_base.update(
{
"input_ids",
"attention_mask",
"pixel_values",
"image_grid_thw",
"mm_token_type_ids",
"pixel_values_videos",
"video_grid_thw",
}
)
allowed_base.add("action_mask")
return {
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
self,
inputs: dict[str, Tensor],
*,
inference_delay: object,
prev_chunk_left_over: object,
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
if prev_chunk_left_over is None:
return inputs, None
if not isinstance(prev_chunk_left_over, torch.Tensor):
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
if prev_chunk_left_over.numel() == 0:
return inputs, None
prev_actions = prev_chunk_left_over
if prev_actions.ndim == 2:
prev_actions = prev_actions.unsqueeze(0)
elif prev_actions.ndim != 3:
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
state = inputs.get("state")
if state is None:
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
batch_size = state.shape[0]
if prev_actions.shape[0] == 1 and batch_size > 1:
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
elif prev_actions.shape[0] != batch_size:
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
# The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
# provided prefix row as a real action constraint, so strip that padding
# before constructing the native overlap options.
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
if valid_prefix_rows.any():
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
prev_actions = prev_actions[:, :valid_prefix_steps, :]
else:
return inputs, None
model_action_horizon = int(
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
)
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon:
prev_actions = prev_actions[:, -model_action_horizon:, :]
action_horizon = int(prev_actions.shape[1])
if action_horizon <= 0:
return inputs, None
if prev_actions.shape[2] > max_action_dim:
prev_actions = prev_actions[:, :, :max_action_dim]
elif prev_actions.shape[2] < max_action_dim:
pad = torch.zeros(
prev_actions.shape[0],
prev_actions.shape[1],
max_action_dim - prev_actions.shape[2],
dtype=prev_actions.dtype,
device=prev_actions.device,
)
prev_actions = torch.cat([prev_actions, pad], dim=2)
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
rtc_config = getattr(self.config, "rtc_config", None)
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
overlap_steps = max(0, min(action_horizon, execution_horizon))
if overlap_steps == 0:
return inputs, None
try:
frozen_steps = int(inference_delay or 0)
except (TypeError, ValueError):
frozen_steps = 0
frozen_steps = max(0, min(frozen_steps, overlap_steps))
options = {
"action_horizon": action_horizon,
"rtc_overlap_steps": overlap_steps,
"rtc_frozen_steps": frozen_steps,
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
}
inputs = dict(inputs)
inputs["action"] = prev_actions
return inputs, options
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Training forward pass.
Delegates to Isaac-GR00T model.forward when inputs are compatible.
"""
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = get_device_from_parameters(self)
device = next(self.parameters()).device
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
@@ -402,52 +248,38 @@ class GrootPolicy(PreTrainedPolicy):
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss")
if loss is None:
raise RuntimeError(
"GR00T model.forward did not return a 'loss'. Training batches must include "
"'action' and 'action_mask'; check the preprocessor output."
)
loss_dict = {"loss": loss.item()}
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
Returns a tensor of shape (B, n_action_steps, action_dim).
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
action-overlap options before calling the underlying model.
"""
self.eval()
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
groot_inputs,
inference_delay=kwargs.get("inference_delay"),
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
)
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
# Preprocessing is handled by the processor pipeline, so we just filter the batch
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
allowed_base = {"state", "state_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = get_device_from_parameters(self)
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
if groot_options is not None:
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
else:
outputs = self._groot_model.get_action(groot_inputs)
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
prediction_horizon = self._resolve_prediction_horizon(actions)
actions = actions[:, :prediction_horizon]
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
@@ -460,5 +292,40 @@ class GrootPolicy(PreTrainedPolicy):
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
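The action-queue behaviour in `select_action` (predict a chunk only when the queue is empty, then serve one action per call) can be sketched for a single environment as follows. `ChunkedActionQueue` and its `predict_chunk` callable are hypothetical names for illustration; the real policy queues tensors transposed to time-major order instead of plain lists.

```python
from collections import deque


class ChunkedActionQueue:
    """Serve one action per call, refilling from a chunk predictor when empty."""

    def __init__(self, predict_chunk, n_action_steps):
        self.predict_chunk = predict_chunk  # hypothetical callable returning a list of actions
        self.n_action_steps = n_action_steps
        self.calls = 0
        self.reset()

    def reset(self):
        # Drop any queued actions, e.g. when the environment resets.
        self.queue = deque([], maxlen=self.n_action_steps)

    def select_action(self, observation):
        if len(self.queue) == 0:
            self.calls += 1
            chunk = self.predict_chunk(observation)
            # Keep only the first n_action_steps actions of the predicted chunk.
            self.queue.extend(chunk[: self.n_action_steps])
        return self.queue.popleft()


policy = ChunkedActionQueue(lambda obs: [obs + i for i in range(5)], n_action_steps=3)
served = [policy.select_action(10) for _ in range(4)]
```

Slicing the chunk before `extend` matters because a `deque` with `maxlen` silently discards items from the left when overfilled, which would drop the earliest actions instead of the latest.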
# -------------------------
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Handle Flash Attention compatibility issues by setting environment variables.
This addresses the common 'undefined symbol' error that occurs when Flash Attention
is compiled against a different PyTorch version than what's currently installed.
"""
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try:
import flash_attn
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError as e:
print(f"[GROOT] Flash Attention not available: {e}")
print("[GROOT] Will use fallback attention mechanism")
except Exception as e:
if "undefined symbol" in str(e):
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")
File diff suppressed because it is too large
+36 -245
@@ -1,256 +1,47 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared, side-effect-free utilities for the GR00T N1.7 policy.
These helpers are consumed by both the config layer (checkpoint sidecar
inspection) and the processor layer (stat flattening, action decoding, language
and image packing). They are pure functions with no GR00T-specific state so they
can be unit-tested in isolation and reused without importing the heavier
config/processor modules.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from shutil import copytree
import numpy as np
import torch
from huggingface_hub import hf_hub_download
def read_json(path: Path) -> dict[str, Any]:
"""Read a JSON object from ``path``, returning ``{}`` on any read/parse error."""
try:
with path.open() as f:
data = json.load(f)
except (OSError, json.JSONDecodeError):
return {}
return data if isinstance(data, dict) else {}
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
def as_int_pair(value: Any) -> list[int] | None:
if not isinstance(value, (list, tuple)) or len(value) != 2:
return None
try:
return [int(value[0]), int(value[1])]
except (TypeError, ValueError):
return None
def as_optional_int(value: Any) -> int | None:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def as_optional_float(value: Any) -> float | None:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
return None
def as_float_list(values: Any) -> list[float]:
if values is None:
return []
if isinstance(values, torch.Tensor):
return values.detach().cpu().reshape(-1).float().tolist()
if isinstance(values, np.ndarray):
return values.reshape(-1).astype(np.float32).tolist()
if isinstance(values, (list, tuple)):
flattened: list[float] = []
for value in values:
flattened.extend(as_float_list(value))
return flattened
return [float(values)]
def config_value(value: Any) -> str:
if hasattr(value, "value"):
value = value.value
text = str(value).lower()
return {
"relative": "relative",
"absolute": "absolute",
"delta": "delta",
"eef": "eef",
"non_eef": "non_eef",
"default": "default",
"xyz_rot6d": "xyz+rot6d",
"xyz+rot6d": "xyz+rot6d",
"xyz_rotvec": "xyz+rotvec",
"xyz+rotvec": "xyz+rotvec",
}.get(text, text)
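The normalization done by `config_value` can be illustrated in isolation. The sketch below is a condensed equivalent, not the file's code: it keeps only the two alias entries that actually change the text, since every other entry in the mapping above maps a string to itself. The `ActionRep` enum is a hypothetical stand-in for whatever enum-like objects a checkpoint config may carry.

```python
from enum import Enum
from typing import Any

# Only the entries that rewrite the text; identity entries are covered by the default.
ALIASES = {"xyz_rot6d": "xyz+rot6d", "xyz_rotvec": "xyz+rotvec"}


def config_value(value: Any) -> str:
    """Unwrap enum-like values, lower-case them, and apply format aliases."""
    if hasattr(value, "value"):
        value = value.value
    text = str(value).lower()
    return ALIASES.get(text, text)


class ActionRep(Enum):  # hypothetical enum, for illustration only
    RELATIVE = "RELATIVE"


print(config_value(ActionRep.RELATIVE))  # enum member is unwrapped and lower-cased
print(config_value("XYZ_ROT6D"))         # underscore spelling maps to "xyz+rot6d"
```

Unknown strings pass through lower-cased, so callers comparing against `"relative"` do not need to care whether the config stored an enum, an upper-case string, or a plain string.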
def has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
if not stats:
return False
return any(bool(modality_stats) for modality_stats in stats.values())
def stat_dim_from_entry(entry: dict[str, Any]) -> int:
for stat_name in ("mean", "q01", "min", "max", "std"):
value = entry.get(stat_name)
if isinstance(value, list) and len(value) > 0:
return len(value)
return 0
def flatten_n1_7_modality_stats(
*,
embodiment_stats: dict[str, Any],
embodiment_config: dict[str, Any],
modality: str,
use_percentiles: bool,
use_relative_action: bool,
) -> dict[str, list[float]]:
"""Flatten one N1.7 modality's grouped statistics in checkpoint order.
When checkpoints request percentile normalization, q01/q99 replace min/max
for regular groups. Relative action groups read from ``relative_action``
stats and keep min/max, matching Isaac-GR00T's processor override.
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
"""
cache_dir = Path(cache_dir)
vendor_dir = Path(vendor_dir)
source_stats = embodiment_stats.get(modality, {})
modality_config = embodiment_config.get(modality, {})
if not isinstance(source_stats, dict) or not isinstance(modality_config, dict):
return {}
modality_keys = modality_config.get("modality_keys", [])
if not isinstance(modality_keys, list):
return {}
try:
# Populate/refresh cache with vendor files to ensure a complete processor directory
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
flattened: dict[str, list[float]] = {}
action_configs = modality_config.get("action_configs", []) if modality == "action" else []
if not isinstance(action_configs, list):
action_configs = []
relative_stats = embodiment_stats.get("relative_action", {})
if not isinstance(relative_stats, dict):
relative_stats = {}
required_assets = [
"vocab.json",
"merges.txt",
"added_tokens.json",
"chat_template.json",
"special_tokens_map.json",
"config.json",
"generation_config.json",
"preprocessor_config.json",
"processor_config.json",
"tokenizer_config.json",
]
for stat_name in ("min", "max", "mean", "std"):
values: list[float] = []
source_stat_name = stat_name
if use_percentiles and stat_name == "min":
source_stat_name = "q01"
elif use_percentiles and stat_name == "max":
source_stat_name = "q99"
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
for idx, modality_key in enumerate(modality_keys):
if not isinstance(modality_key, str):
continue
key_source_stats = source_stats
key_stat_name = source_stat_name
if modality == "action" and use_relative_action and idx < len(action_configs):
action_config = action_configs[idx]
if isinstance(action_config, dict) and config_value(action_config.get("rep")) == "relative":
key_source_stats = relative_stats
key_stat_name = stat_name
key_stats = key_source_stats.get(modality_key, {})
if not isinstance(key_stats, dict):
raise KeyError(f"Missing statistics for {modality}.{modality_key}")
raw_values = key_stats.get(key_stat_name)
if raw_values is None:
raise KeyError(f"Missing '{key_stat_name}' statistics for {modality}.{modality_key}")
values.extend(as_float_list(raw_values))
if values:
flattened[stat_name] = values
return flattened
def rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
rows = rot6d.reshape(2, 3).astype(np.float64)
row1 = rows[0] / np.linalg.norm(rows[0])
row2 = rows[1] - np.dot(row1, rows[1]) * row1
row2 = row2 / np.linalg.norm(row2)
row3 = np.cross(row1, row2)
return np.vstack([row1, row2, row3])
def xyz_rot6d_to_homogeneous(xyz_rot6d: np.ndarray) -> np.ndarray:
transform = np.eye(4, dtype=np.float64)
transform[:3, :3] = rot6d_to_matrix(xyz_rot6d[3:])
transform[:3, 3] = xyz_rot6d[:3]
return transform
def homogeneous_to_xyz_rot6d(transform: np.ndarray) -> np.ndarray:
return np.concatenate([transform[:3, 3], transform[:2, :3].reshape(-1)], axis=0)
def relative_eef_to_absolute(action: np.ndarray, reference_state: np.ndarray) -> np.ndarray:
"""Convert relative EEF deltas in xyz+rot6d format to absolute EEF poses."""
out = np.empty_like(action, dtype=np.float64)
for batch_idx in range(action.shape[0]):
reference = xyz_rot6d_to_homogeneous(reference_state[batch_idx])
for timestep in range(action.shape[1]):
relative = xyz_rot6d_to_homogeneous(action[batch_idx, timestep])
out[batch_idx, timestep] = homogeneous_to_xyz_rot6d(reference @ relative)
return out.astype(np.float32)
def infer_n1_7_batch_size_and_device(
obs: dict[str, Any], action: torch.Tensor | None
) -> tuple[int, torch.device]:
for value in list(obs.values()) + [action]:
if isinstance(value, torch.Tensor):
return value.shape[0], value.device
video = obs.get("video")
if isinstance(video, np.ndarray):
return video.shape[0], torch.device("cpu")
return 1, torch.device("cpu")
def prepare_n1_7_language_batch(
language: Any,
batch_size: int,
*,
formalize_language: bool,
) -> list[str]:
default_language = "Perform the task."
if language is None or (isinstance(language, str) and language == ""):
languages = [default_language] * batch_size
elif isinstance(language, str):
languages = [language] * batch_size
elif isinstance(language, (list, tuple)):
languages = list(language)
if len(languages) == 0:
languages = [default_language] * batch_size
elif len(languages) == 1 and batch_size > 1:
languages = languages * batch_size
elif len(languages) != batch_size:
raise ValueError(
f"language batch has {len(languages)} entries, but GR00T N1.7 input batch has {batch_size}."
for fname in required_assets:
dst = cache_dir / fname
if not dst.exists():
print(f"[GROOT] Fetching {fname}")
hf_hub_download(
repo_id=assets_repo,
filename=fname,
repo_type="model",
local_dir=str(cache_dir),
)
else:
languages = [str(language)] * batch_size
formatted = []
for item in languages:
text = str(item) if item else default_language
if formalize_language:
text = text.lower()
text = "".join(ch for ch in text if ch.isalnum() or ch.isspace() or ch == "_")
formatted.append(text)
return formatted
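The `formalize_language` branch above lower-cases each instruction and drops every character that is not alphanumeric, whitespace, or an underscore. A minimal standalone sketch of just that step follows; the helper name `formalize` and the sample instruction are illustrative and do not appear in the diff.

```python
def formalize(text: str, default: str = "Perform the task.") -> str:
    """Lower-case an instruction and strip punctuation, keeping '_' and whitespace."""
    text = text if text else default
    text = text.lower()
    return "".join(ch for ch in text if ch.isalnum() or ch.isspace() or ch == "_")


print(formalize("Pick up the Red-Cube, please!"))
print(formalize(""))  # an empty instruction falls back to the default prompt
```

Note that hyphens are removed rather than replaced by a space, so hyphenated words are joined together.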
+279 -55
@@ -32,7 +32,6 @@ from __future__ import annotations
import importlib
import json
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
_serialized_state_filenames: tuple[str | None, ...] | None = field(
default=None,
init=False,
repr=False,
)
def __call__(self, data: TInput) -> TOutput:
"""Processes input data through the full pipeline.
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
transition = processor_step(transition)
yield transition
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `HubMixin`'s saving mechanism.
def _get_sanitized_name(self) -> str:
"""Return a filename-safe version of the pipeline name.
This method does the actual saving work and is called by HubMixin.save_pretrained.
Returns:
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
"""
config_filename = kwargs.pop("config_filename", None)
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
@staticmethod
def _get_state_filename(
*,
step_index: int,
registry_name: str | None,
sanitized_name: str,
) -> str:
"""Return the safetensors filename for one stateful processor step.
if config_filename is None:
config_filename = f"{sanitized_name}.json"
Args:
step_index: The index of the processor step in this pipeline.
registry_name: The registered processor step name, if available.
sanitized_name: The filename-safe pipeline name.
config: dict[str, Any] = {
Returns:
The state filename used by the existing disk serialization format.
"""
if registry_name:
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
return f"{sanitized_name}_step_{step_index}.safetensors"
@staticmethod
def _get_state_key(state_filename: str) -> str:
"""Return the in-memory state key for a serialized state filename.
Args:
state_filename: The `.safetensors` filename from the serialized config.
Returns:
The state key used by the in-memory pipeline state dictionary.
"""
return state_filename.removesuffix(".safetensors")
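To make the naming scheme concrete, the following standalone sketch re-implements the three helpers above as free functions and traces one name through them. The pipeline name and the registry name `normalizer_processor` are made-up inputs for illustration.

```python
import re


def sanitize_name(name: str) -> str:
    """Lower-case the name and replace anything outside [a-zA-Z0-9_] with '_'."""
    return re.sub(r"[^a-zA-Z0-9_]", "_", name.lower())


def state_filename(step_index: int, registry_name, sanitized_name: str) -> str:
    """Build the on-disk safetensors filename for one stateful step."""
    if registry_name:
        return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
    return f"{sanitized_name}_step_{step_index}.safetensors"


def state_key(filename: str) -> str:
    """Derive the in-memory key by dropping the '.safetensors' suffix."""
    return filename.removesuffix(".safetensors")


prefix = sanitize_name("Policy Preprocessor")
filename = state_filename(2, "normalizer_processor", prefix)
print(filename)
print(state_key(filename))
```

Because the in-memory key is just the filename without its suffix, the disk format and the in-memory format stay in one-to-one correspondence.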
@staticmethod
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
"""Return serialized state filenames in step order.
Args:
loaded_config: A validated processor pipeline config.
Returns:
A tuple containing each step's serialized state filename, or None for stateless steps.
"""
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
"""Return expected state filenames in step order for `load_state_dict()`.
Returns:
The preserved serialized state filenames when available, otherwise filenames derived from
current non-empty step state.
"""
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
self.steps
):
return self._serialized_state_filenames
sanitized_name = self._get_sanitized_name()
state_filenames: list[str | None] = []
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
state_filenames.append(None)
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filenames.append(
self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
)
return tuple(state_filenames)
def get_config(self) -> dict[str, Any]:
"""Return the JSON-serializable pipeline configuration.
Returns:
A dictionary with the same content that `save_pretrained()` writes as JSON.
"""
sanitized_name = self._get_sanitized_name()
pipeline_config: dict[str, Any] = {
"name": self.name,
"steps": [],
}
# Iterate through each step to build its configuration entry.
for step_index, processor_step in enumerate(self.steps):
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
# Prefer registry name for portability, otherwise fall back to full class path.
if registry_name:
step_entry["registry_name"] = registry_name
else:
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
# Save step configuration if `get_config` is implemented.
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
step_entry["config"] = processor_step.get_config()
# Save step state if `state_dict` is implemented and returns a non-empty dict.
if hasattr(processor_step, "state_dict"):
state = processor_step.state_dict()
if state:
# Clone tensors to avoid modifying the original state.
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
step_state_dict = processor_step.state_dict()
if step_state_dict:
step_entry["state_file"] = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
# Create a unique filename for the state file.
if registry_name:
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
pipeline_config["steps"].append(step_entry)
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
return pipeline_config
config["steps"].append(step_entry)
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
"""Return pipeline state tensors grouped by state key.
# Write the main configuration JSON file.
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
Returns:
A dictionary mapping suffixless state keys to cloned step state dictionaries.
"""
sanitized_name = self._get_sanitized_name()
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filename = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
state_key = self._get_state_key(state_filename)
pipeline_state_dict[state_key] = {
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
}
return pipeline_state_dict
def load_state_dict(
self,
state_dict: dict[str, dict[str, torch.Tensor]],
) -> None:
"""Load pipeline state tensors into the existing steps.
Args:
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
Raises:
KeyError: If loading finds missing expected state or unexpected extra state.
"""
expected_state_filenames = self._get_state_filenames_for_loading()
used_state_keys: set[str] = set()
for step_index, (processor_step, state_filename) in enumerate(
zip(self.steps, expected_state_filenames, strict=True)
):
if state_filename is None:
continue
state_key = self._get_state_key(state_filename)
if state_key not in state_dict:
raise KeyError(
f"Missing state key '{state_key}' for processor step {step_index}. "
f"Available state keys: {sorted(state_dict.keys())}"
)
processor_step.load_state_dict(state_dict[state_key])
used_state_keys.add(state_key)
unexpected_state_keys = set(state_dict) - used_state_keys
if unexpected_state_keys:
expected_state_key_set = {
self._get_state_key(state_filename)
for state_filename in expected_state_filenames
if state_filename is not None
}
raise KeyError(
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
f"Expected state keys: {sorted(expected_state_key_set)}"
)
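The strictness of `load_state_dict()` (every expected key must be present, and no extra key may be supplied) can be sketched without any processor classes. The function below is a simplified stand-in that only performs the key bookkeeping; `check_state_keys` and the sample keys are hypothetical names, and the real method also forwards each state to its step.

```python
from __future__ import annotations


def check_state_keys(expected_keys: list[str | None], provided: dict) -> None:
    """Raise KeyError on missing or unexpected state keys; None marks a stateless step."""
    used: set[str] = set()
    for step_index, key in enumerate(expected_keys):
        if key is None:
            continue  # stateless step, nothing to load
        if key not in provided:
            raise KeyError(f"Missing state key '{key}' for processor step {step_index}.")
        used.add(key)
    unexpected = set(provided) - used
    if unexpected:
        raise KeyError(f"Unexpected processor state keys: {sorted(unexpected)}.")


# A two-step pipeline where only step 0 carries state.
check_state_keys(["p_step_0", None], {"p_step_0": {}})
```

Rejecting extra keys as well as missing ones means a state dictionary produced by a different pipeline layout fails loudly instead of being partially applied.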
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
"""Internal method to comply with `HubMixin`'s saving mechanism.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
config_filename = kwargs.pop("config_filename", None)
sanitized_name = self._get_sanitized_name()
if config_filename is None:
config_filename = f"{sanitized_name}.json"
pipeline_config = self.get_config()
pipeline_state_dict = self.state_dict()
for state_key, step_state_dict in pipeline_state_dict.items():
state_filename = f"{state_key}.safetensors"
save_file(step_state_dict, save_directory / state_filename)
with open(save_directory / config_filename, "w") as file_pointer:
json.dump(pipeline_config, file_pointer, indent=2)
def save_pretrained(
self,
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
cls._validate_overrides_used(validated_overrides, loaded_config)
# 5. Construct and return the final pipeline instance
return cls(
pipeline = cls(
steps=steps,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
return pipeline
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[TInput], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
"""Build a pipeline from an in-memory config and optional state tensors.
Args:
config: A config dictionary with the same structure as the saved processor JSON.
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
overrides: Optional constructor overrides keyed by registry name or class name.
to_transition: Optional converter from input data to `EnvTransition`.
to_output: Optional converter from `EnvTransition` to output data.
Returns:
A processor pipeline built from the config and optional state.
"""
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
cls._validate_overrides_used(remaining_override_keys, config)
pipeline = cls(
steps=steps,
name=config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
if state_dict is not None:
pipeline.load_state_dict(state_dict)
return pipeline
@classmethod
def _load_config(
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
) from e
@classmethod
def _validate_loaded_config(
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
) -> None:
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
"""Validate that a config was loaded and is a valid processor config.
This method validates processor config format with intelligent migration detection:
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Args:
model_id: The model identifier (used for migration detection)
loaded_config: The loaded config dictionary (guaranteed non-None)
loaded_config: The loaded config value to validate (may be non-dict)
config_filename: The config filename that was loaded (for error messages)
Raises:
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
model_id,
f"Config file '{config_filename}' is not a valid processor configuration",
)
loaded_config_description = (
list(loaded_config.keys())
if isinstance(loaded_config, dict)
else type(loaded_config).__name__
)
raise ValueError(
f"Config file '{config_filename}' is not a valid processor configuration. "
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
)
@classmethod
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
ImportError: If a step class cannot be imported or found in registry
ValueError: If a step cannot be instantiated with its configuration
"""
steps: list[ProcessorStep] = []
override_keys = set(overrides.keys())
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
for step_entry in loaded_config["steps"]:
# 1. Get step class and key
step_class, step_key = cls._resolve_step_class(step_entry)
# 2. Instantiate step with overrides
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
# 3. Load step state if available
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
# 4. Track used overrides
if step_key in override_keys:
override_keys.discard(step_key)
return steps, remaining_override_keys
steps.append(step_instance)
@classmethod
def _build_steps_from_config(
cls,
loaded_config: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[list[ProcessorStep], set[str]]:
"""Build processor steps from config without loading tensor state.
return steps, override_keys
Args:
loaded_config: The loaded processor configuration.
overrides: User-provided constructor overrides keyed by step key.
Returns:
A tuple containing instantiated steps and override keys that did not match a step.
"""
processor_steps: list[ProcessorStep] = []
remaining_override_keys = set(overrides.keys())
for step_entry in loaded_config["steps"]:
step_class, step_key = cls._resolve_step_class(step_entry)
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
if step_key in remaining_override_keys:
remaining_override_keys.discard(step_key)
processor_steps.append(processor_step)
return processor_steps, remaining_override_keys
@classmethod
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
return True
@classmethod
def _is_processor_config(cls, config: dict) -> bool:
def _is_processor_config(cls, config: Any) -> bool:
"""Check if config follows DataProcessorPipeline format.
This method validates the processor configuration structure:
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Returns:
True if config follows valid DataProcessorPipeline format, False otherwise
"""
if not isinstance(config, dict):
return False
# Must have a "steps" field with a list of step configurations
if not isinstance(config.get("steps"), list):
return False
+4
@@ -23,6 +23,7 @@ from .configs import (
DAggerKeyboardConfig,
DAggerPedalConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
HighlightStrategyConfig,
RolloutConfig,
RolloutStrategyConfig,
@@ -49,6 +50,7 @@ from .inference import (
from .strategies import (
BaseStrategy,
DAggerStrategy,
EpisodicStrategy,
HighlightStrategy,
RolloutStrategy,
SentryStrategy,
@@ -66,6 +68,8 @@ __all__ = [
"HardwareContext",
"HighlightStrategy",
"HighlightStrategyConfig",
"EpisodicStrategy",
"EpisodicStrategyConfig",
"InferenceEngine",
"InferenceEngineConfig",
"PolicyContext",
+36 -1
@@ -121,6 +121,35 @@ class DAggerPedalConfig:
upload: str = "KEY_C"
@RolloutStrategyConfig.register_subclass("episodic")
@dataclass
class EpisodicStrategyConfig(RolloutStrategyConfig):
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
Records ``dataset.num_episodes`` episodes of at most ``dataset.episode_time_s`` seconds each.
After each episode, allows ``dataset.reset_time_s`` seconds for the reset phase.
Keyboard controls:
Right arrow: end current episode or reset phase early
Left arrow: discard current episode and re-record
Escape: stop recording session
In between episodes:
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
- else, the robot is moved smoothly to the position of the teleop leader.
"""
# This only applies if there are no teleop leaders specified.
# When True (default), moves the robot back to the joint positions captured at startup.
# Otherwise, leave the robot in its current position.
reset_to_initial_position: bool = True
# Whether to enable the smooth leader -> follower handover behavior.
# When False, fall back to follower -> leader handover.
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
smooth_leader_to_follower_handover: bool = True
@RolloutStrategyConfig.register_subclass("dagger")
@dataclass
class DAggerStrategyConfig(RolloutStrategyConfig):
@@ -229,7 +258,13 @@ class RolloutConfig:
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
needs_dataset = isinstance(
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
self.strategy,
(
SentryStrategyConfig,
HighlightStrategyConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
),
)
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
@@ -17,6 +17,7 @@
from .base import BaseStrategy
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
from .episodic import EpisodicStrategy
from .factory import create_strategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
@@ -27,6 +28,7 @@ __all__ = [
"DAggerPhase",
"DAggerStrategy",
"HighlightStrategy",
"EpisodicStrategy",
"RolloutStrategy",
"SentryStrategy",
"create_strategy",
+14 -69
@@ -56,10 +56,14 @@ from typing import Any
import numpy as np
from lerobot.common.control_utils import is_headless
from lerobot.common.control_utils import (
follower_smooth_move_to,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.datasets import VideoEncodingManager
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
from lerobot.teleoperators import Teleoperator
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.import_utils import _pynput_available
@@ -69,7 +73,6 @@ from lerobot.utils.utils import log_say
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
from ..context import RolloutContext
from ..robot_wrapper import ThreadSafeRobot
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
PYNPUT_AVAILABLE = _pynput_available
@@ -171,64 +174,6 @@ class DAggerEvents:
self.upload_requested.clear()
# ---------------------------------------------------------------------------
# Teleoperator helpers
# ---------------------------------------------------------------------------
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
"""Return True when the teleop can receive position feedback (is actuated).
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
"""
return (
bool(teleop.feedback_features)
and hasattr(teleop, "disable_torque")
and hasattr(teleop, "enable_torque")
)
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
) -> None:
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
Requires the teleoperator to support feedback
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
TODO(Maxime): This blocks for up to ``duration_s`` seconds. During this time
the follower robot receives no new actions, which could be an issue on LeKiwi.
"""
teleop.enable_torque()
current = teleop.get_action()
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
}
teleop.send_feedback(interp)
time.sleep(1 / fps)
def _follower_smooth_move_to(
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
) -> None:
"""Smoothly move the follower robot from ``current`` to ``target`` action.
Used when the teleop is non-actuated: instead of driving the leader arm
to the follower, we bring the follower to the teleop's current pose.
Both ``current`` and ``target`` must be in robot-action key space.
"""
steps = max(int(duration_s * fps), 1)
for step in range(steps + 1):
t = step / steps
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
robot.send_action(interp)
time.sleep(1 / fps)
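Both smooth-move helpers above share the same per-key linear interpolation. The sketch below isolates it as a pure function that returns the waypoints instead of sending them to hardware, so it can be inspected without a robot. The function name `interpolate_waypoints` and the joint keys are hypothetical.

```python
def interpolate_waypoints(current: dict, target: dict, steps: int) -> list[dict]:
    """Return steps + 1 waypoints moving linearly from ``current`` toward ``target``.

    Keys missing from ``target`` keep their current value, as in the helpers above.
    """
    steps = max(steps, 1)
    waypoints = []
    for step in range(steps + 1):
        t = step / steps
        waypoints.append(
            {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
        )
    return waypoints


current = {"shoulder_pan.pos": 0.0, "gripper.pos": 1.0}  # hypothetical joint keys
target = {"shoulder_pan.pos": 10.0}
for waypoint in interpolate_waypoints(current, target, steps=2):
    print(waypoint)
```

Since only keys of `current` are iterated, a `target` whose keys do not overlap with `current` yields waypoints identical to `current`, which is the silent no-op that the TODO in the phase-transition code warns about when action keys are renamed.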
# ---------------------------------------------------------------------------
# Input device handlers
# ---------------------------------------------------------------------------
@@ -756,31 +701,31 @@ class DAggerStrategy(RolloutStrategy):
logger.info("Pausing engine - robot holds position")
engine.pause()
if _teleop_supports_feedback(teleop) and prev_action is not None:
if teleop_supports_feedback(teleop) and prev_action is not None:
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
logger.info("Smooth handover: moving leader arm to follower position")
_teleop_smooth_move_to(teleop, prev_action)
teleop_smooth_move_to(teleop, prev_action)
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
logger.info("Entering correction mode - human teleop control")
if not _teleop_supports_feedback(teleop) and prev_action is not None:
if not teleop_supports_feedback(teleop) and prev_action is not None:
logger.info("Smooth handover: sliding follower to teleop position")
obs = robot.get_observation()
teleop_action = teleop.get_action()
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
target = ctx.processors.robot_action_processor((processed, obs))
_follower_smooth_move_to(robot, prev_action, target)
follower_smooth_move_to(robot, prev_action, target)
# unlock the teleop for human control
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.disable_torque()
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.enable_torque()
elif new_phase == DAggerPhase.AUTONOMOUS:
@@ -790,7 +735,7 @@ class DAggerStrategy(RolloutStrategy):
engine.resume()
# release teleop before resuming the policy
if _teleop_supports_feedback(teleop):
if teleop_supports_feedback(teleop):
teleop.disable_torque()
# ------------------------------------------------------------------
+335
@@ -0,0 +1,335 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
- Policy drives the robot during each recording episode.
- An optional teleoperator can drive the robot during reset phases so the
operator can bring the environment back to its starting configuration.
If no teleop is connected, the robot returns to its startup position when
``reset_to_initial_position`` is set; otherwise it stays in its current position.
- Keyboard controls:
Right arrow end the current episode or reset phase early
Left arrow discard the current episode and re-record it
Escape stop the recording session
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
"""
from __future__ import annotations
import contextlib
import logging
import time
from lerobot.common.control_utils import (
follower_smooth_move_to,
init_keyboard_listener,
is_headless,
teleop_smooth_move_to,
teleop_supports_feedback,
)
from lerobot.datasets import VideoEncodingManager
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import log_rerun_data
from ..configs import EpisodicStrategyConfig
from ..context import RolloutContext
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
logger = logging.getLogger(__name__)
class EpisodicStrategy(RolloutStrategy):
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
Each recording episode runs the policy for at most ``dataset.episode_time_s``
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
follows every episode (except the last) so the operator can manually
reset the environment. During the reset phase, an optional teleoperator
drives the robot; if none is present and ``reset_to_initial_position`` is set,
the robot returns to the initial joint positions captured at startup.
The policy state (hidden state, RTC queue, interpolator) is reset at
the start of each recording episode.
Keyboard events:
right arrow end current episode or reset phase early
left arrow discard & re-record current episode
ESC stop the session
"""
config: EpisodicStrategyConfig
def __init__(self, config: EpisodicStrategyConfig) -> None:
super().__init__(config)
self._listener = None
self._events: dict | None = None
def setup(self, ctx: RolloutContext) -> None:
"""Start the inference engine and attach the keyboard listener."""
self._init_engine(ctx)
self._listener, self._events = init_keyboard_listener()
logger.info("Episodic strategy ready")
def run(self, ctx: RolloutContext) -> None:
"""Main multi-episode recording loop."""
cfg = ctx.runtime.cfg
dataset_cfg = cfg.dataset
robot = ctx.hardware.robot_wrapper
teleop = ctx.hardware.teleop
dataset = ctx.data.dataset
events = self._events
features = ctx.data.dataset_features
fps = cfg.fps
episode_time_s = dataset_cfg.episode_time_s
reset_time_s = dataset_cfg.reset_time_s
num_episodes = dataset_cfg.num_episodes
single_task = dataset_cfg.single_task or cfg.task
play_sounds = cfg.play_sounds
display_compressed = (
True
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
else cfg.display_compressed_images
)
with VideoEncodingManager(dataset):
try:
recorded_episodes = 0
while recorded_episodes < num_episodes and not events["stop_recording"]:
if ctx.runtime.shutdown_event.is_set():
break
# Reset policy state at episode start (discard leftover hidden state / queue)
self._engine.reset()
self._interpolator.reset()
self._engine.resume()
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
self._policy_loop(
ctx=ctx,
robot=robot,
events=events,
features=features,
fps=fps,
control_time_s=episode_time_s,
dataset=dataset,
single_task=single_task,
)
# Reset phase, skip after the last episode (but run when re-recording)
if not events["stop_recording"] and (
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
if teleop:
# Smooth handover so the transition to teleop control is jerk-free.
# For actuated teleops: drive the leader arm to the follower's current
# position so the operator takes over without fighting the arm.
# For non-actuated teleops: slide the follower to the teleop's current
# pose instead, since the leader cannot be driven.
obs = robot.get_observation()
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
if (
teleop_supports_feedback(teleop)
and self.config.smooth_leader_to_follower_handover
):
logger.info("Smooth handover: moving leader arm to follower position")
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
teleop.disable_torque()
else:
logger.info("Smooth handover: sliding follower to teleop position")
teleop_action = teleop.get_action()
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
target = ctx.processors.robot_action_processor((processed, obs))
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
elif self.config.reset_to_initial_position:
# No teleop: return the robot to its startup position.
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
self._reset_loop(
ctx=ctx,
robot=robot,
teleop=teleop,
events=events,
fps=fps,
control_time_s=reset_time_s,
display_data=cfg.display_data,
display_compressed=display_compressed,
)
if events["rerecord_episode"]:
log_say("Re-record episode", play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
# No teleop: return the robot to the initial joint positions captured at startup.
if not teleop and self.config.reset_to_initial_position:
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
continue
dataset.save_episode()
recorded_episodes += 1
finally:
# Save any frames buffered in the current episode so an unexpected
# exception or KeyboardInterrupt does not silently drop recorded data.
# suppress: save_episode raises if the buffer is empty (nothing to lose).
logger.info("Episodic control loop ended — saving any in-progress episode")
with contextlib.suppress(Exception):
dataset.save_episode()
def _policy_loop(
self,
ctx: RolloutContext,
robot,
events: dict,
features: dict,
fps: float,
control_time_s: float,
dataset,
single_task: str,
) -> None:
"""Policy-driven recording loop for a single episode."""
interpolator = self._interpolator
control_interval = interpolator.get_control_interval(fps)
timestamp = 0.0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
if ctx.runtime.shutdown_event.is_set():
break
obs = robot.get_observation()
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
continue
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
if action_dict is not None:
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
dt = time.perf_counter() - loop_start
sleep_t = control_interval - dt
if sleep_t < 0:
logger.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
"Dataset frames might be dropped and robot control might be unstable. "
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
"3) CPU starvation"
)
precise_sleep(max(sleep_t, 0.0))
timestamp = time.perf_counter() - start_t
def _reset_loop(
self,
ctx: RolloutContext,
robot,
teleop,
events: dict,
fps: float,
control_time_s: float,
display_data: bool,
display_compressed: bool,
) -> None:
"""Reset-phase loop: teleop drives the robot if available, no recording."""
processors = ctx.processors
control_interval = 1.0 / fps
timestamp = 0.0
start_t = time.perf_counter()
while timestamp < control_time_s:
loop_start = time.perf_counter()
if events["exit_early"]:
events["exit_early"] = False
break
if ctx.runtime.shutdown_event.is_set():
break
obs = robot.get_observation()
if teleop is not None:
act = teleop.get_action()
act_teleop = processors.teleop_action_processor((act, obs))
robot_action = processors.robot_action_processor((act_teleop, obs))
robot.send_action(robot_action)
if display_data:
obs_processed = processors.robot_observation_processor(obs)
log_rerun_data(
observation=obs_processed,
action=act_teleop,
compress_images=display_compressed,
)
dt = time.perf_counter() - loop_start
sleep_t = control_interval - dt
precise_sleep(max(sleep_t, 0.0))
timestamp = time.perf_counter() - start_t
def teardown(self, ctx: RolloutContext) -> None:
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
cfg = ctx.runtime.cfg
play_sounds = cfg.play_sounds
log_say("Stop recording", play_sounds, blocking=True)
if not is_headless() and self._listener is not None:
self._listener.stop()
if ctx.data.dataset is not None:
logger.info("Finalizing dataset...")
ctx.data.dataset.finalize()
if (
cfg.dataset is not None
and cfg.dataset.push_to_hub
and ctx.data.dataset is not None
and safe_push_to_hub(
ctx.data.dataset,
tags=cfg.dataset.tags,
private=cfg.dataset.private,
)
):
logger.info("Dataset uploaded to hub")
log_say("Dataset uploaded to hub", play_sounds)
self._teardown_hardware(
ctx.hardware,
return_to_initial_position=cfg.return_to_initial_position,
)
log_say("Exiting", play_sounds)
logger.info("Episodic strategy teardown complete")
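Note: both `_policy_loop` and `_reset_loop` follow the same fixed-rate pattern: do the work, measure the elapsed time, sleep for the rest of the `1 / fps` budget, and stop once `control_time_s` has passed. The sketch below shows that pattern in isolation. `run_fixed_rate`, `FakeClock`, and the injectable `clock` / `sleep` parameters are hypothetical names added so the loop can run deterministically; the real loops call `time.perf_counter` and `precise_sleep` directly.

```python
import time
from typing import Callable, Tuple


def run_fixed_rate(
    step_fn: Callable[[], None],
    fps: float,
    control_time_s: float,
    clock: Callable[[], float] = time.perf_counter,
    sleep: Callable[[float], None] = time.sleep,
) -> Tuple[int, int]:
    """Call step_fn at roughly fps Hz for control_time_s seconds.

    Returns (iterations, overruns). An overrun is an iteration whose work
    took longer than the 1 / fps budget.
    """
    control_interval = 1.0 / fps
    iterations = 0
    overruns = 0
    timestamp = 0.0
    start_t = clock()
    while timestamp < control_time_s:
        loop_start = clock()
        step_fn()
        iterations += 1
        sleep_t = control_interval - (clock() - loop_start)
        if sleep_t < 0:
            overruns += 1  # the real loop logs a warning here
        sleep(max(sleep_t, 0.0))
        timestamp = clock() - start_t
    return iterations, overruns


class FakeClock:
    """Deterministic clock so the sketch runs without real sleeping."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now

    def sleep(self, seconds: float) -> None:
        self.now += seconds


fast = FakeClock()
on_time = run_fixed_rate(lambda: fast.sleep(0.125), 4, 1.0, clock=fast, sleep=fast.sleep)

slow = FakeClock()
too_slow = run_fixed_rate(lambda: slow.sleep(0.5), 4, 1.0, clock=slow, sleep=slow.sleep)
```

With a 0.25 s budget, work that takes 0.125 s leaves time to sleep, while work that takes 0.5 s overruns on every iteration and the loop completes fewer iterations in the same wall time.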
+6 -1
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
from .base import BaseStrategy
from .core import RolloutStrategy
from .dagger import DAggerStrategy
from .episodic import EpisodicStrategy
from .highlight import HighlightStrategy
from .sentry import SentryStrategy
@@ -42,4 +43,8 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
return HighlightStrategy(config)
if config.type == "dagger":
return DAggerStrategy(config)
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
if config.type == "episodic":
return EpisodicStrategy(config)
raise ValueError(
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
)
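Note: the error message in `create_strategy` lists the available types by hand, so it has to be edited every time a strategy is added, as this hunk does for `episodic`. One possible alternative, shown only as a sketch, is a registry dict from which both the dispatch and the message are derived. All class names below are hypothetical stand-ins, not the real LeRobot classes.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class DemoStrategyConfig:
    """Hypothetical stand-in for RolloutStrategyConfig."""

    type: str


class DemoStrategy:
    """Hypothetical stand-in for RolloutStrategy."""

    def __init__(self, config: DemoStrategyConfig) -> None:
        self.config = config


class DemoDAggerStrategy(DemoStrategy):
    pass


class DemoEpisodicStrategy(DemoStrategy):
    pass


# Single source of truth: adding a strategy means adding one entry here.
STRATEGIES: Dict[str, Callable[[DemoStrategyConfig], DemoStrategy]] = {
    "dagger": DemoDAggerStrategy,
    "episodic": DemoEpisodicStrategy,
}


def create_demo_strategy(config: DemoStrategyConfig) -> DemoStrategy:
    """Look up the strategy class by config.type and instantiate it."""
    try:
        factory = STRATEGIES[config.type]
    except KeyError:
        available = ", ".join(STRATEGIES)
        raise ValueError(
            f"Unknown strategy type '{config.type}'. Available: {available}"
        ) from None
    return factory(config)


strategy = create_demo_strategy(DemoStrategyConfig(type="episodic"))
```

The explicit `if` chain in the diff is easier to follow for a handful of strategies; the registry mainly helps once the list grows.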
+13
@@ -25,6 +25,7 @@ Strategies
--strategy.type=sentry Continuous recording with auto-upload
--strategy.type=highlight Ring buffer + keystroke save
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
--strategy.type=episodic Episode-oriented recording with reset phases
Inference backends
------------------
@@ -111,6 +112,18 @@ Usage examples
--display_data=true \\
--use_torch_compile=true
# Episodic mode — episode-oriented recording with reset phases
lerobot-rollout \\
--strategy.type=episodic \\
--policy.path=user/my_policy \\
--robot.type=so100_follower \\
--robot.port=/dev/ttyACM0 \\
--teleop.type=so100_leader \\
--teleop.port=/dev/ttyACM1 \\
--dataset.repo_id=user/rollout_episodic_data \\
--dataset.num_episodes=20 \\
--dataset.single_task="Grab the cube"
# Resume a previous sentry recording session
lerobot-rollout \\
--strategy.type=sentry \\
@@ -0,0 +1,386 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pytest
import torch
from safetensors import safe_open
from torch import nn
from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
from lerobot.policies import FastWAMConfig, get_policy_class, make_policy_config, make_pre_post_processors
from lerobot.policies.fastwam.modeling_fastwam import FastWAMPolicy
from lerobot.policies.fastwam.processor_fastwam import FastWAMActionToggleProcessorStep
from lerobot.utils.constants import ACTION, OBS_STATE
class FakeFastWAMCore(nn.Module):
def __init__(self):
super().__init__()
self.dit = nn.Linear(2, 2)
def training_loss(self, sample):
assert sample["video"].ndim == 5
assert sample["context"].ndim == 3
return sample[ACTION].sum() * 0.0 + torch.tensor(1.0), {"loss_action": 1.0}
def infer_action(self, **kwargs):
return {"action": torch.ones(1, kwargs["action_horizon"], 3)}
def test_fastwam_is_registered_and_publicly_exported():
cfg = make_policy_config(
"fastwam",
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
assert isinstance(cfg, FastWAMConfig)
assert cfg.type == "fastwam"
assert get_policy_class("fastwam") is FastWAMPolicy
def test_config_validates_features_model_ids_and_saved_auto_route(tmp_path):
cfg = FastWAMConfig()
cfg.save_pretrained(tmp_path)
saved = json.loads((tmp_path / "config.json").read_text())
assert saved["pretrained_path"] is None
assert cfg.image_features["observation.images.image"].type == FeatureType.VISUAL
assert cfg.action_feature.shape == (7,)
assert cfg.robot_state_feature.shape == (8,)
with pytest.raises(ValueError, match="image feature"):
FastWAMConfig(input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))})
with pytest.raises(ValueError, match="tokenizer_model_id"):
FastWAMConfig(tokenizer_model_id="somebody/other-tokenizer")
def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_path):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(2, 2),
device="cpu",
toggle_action_dimensions=[-1],
input_features={
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 2, 2)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
base_model_id=None,
)
dataset_stats = {
"observation.images.image": {
"mean": torch.full((3, 1, 1), 0.2),
"std": torch.full((3, 1, 1), 0.1),
},
OBS_STATE: {
"mean": torch.tensor([1.0, 3.0]),
"std": torch.tensor([2.0, 4.0]),
},
ACTION: {
"mean": torch.zeros(3),
"std": torch.ones(3),
},
}
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_stats)
processed = preprocessor(
{
"observation.images.image": torch.tensor(
[
[[0.0, 0.5], [1.0, 0.5]],
[[0.0, 0.5], [1.0, 0.5]],
[[0.0, 0.5], [1.0, 0.5]],
]
),
OBS_STATE: torch.tensor([3.0, 7.0]),
}
)
preprocessor.save_pretrained(tmp_path, config_filename="policy_preprocessor.json")
postprocessor.save_pretrained(tmp_path, config_filename="policy_postprocessor.json")
_, loaded_postprocessor = make_pre_post_processors(cfg, pretrained_path=str(tmp_path))
expected_image = torch.tensor(
[[[[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]], [[-1.0, 0.0], [1.0, 0.0]]]]
)
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
assert torch.allclose(processed["observation.images.image"], expected_image)
assert torch.allclose(processed[OBS_STATE], torch.tensor([[1.0, 1.0]]))
assert torch.equal(dataset_stats["observation.images.image"]["mean"], torch.full((3, 1, 1), 0.2))
assert any(isinstance(step, FastWAMActionToggleProcessorStep) for step in loaded_postprocessor.steps)
assert torch.equal(
loaded_postprocessor(torch.tensor([[0.25, 0.5, 1.0]])), torch.tensor([[0.25, 0.5, -1.0]])
)
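Note: the expected values in the test above are consistent with two different transforms: images mapped from [0, 1] to [-1, 1], and the state vector normalized with the dataset mean and std. The sketch below only reproduces that arithmetic. It is inferred from the assertions, not from the processor source, so the helper names and the image scaling rule are assumptions.

```python
def scale_to_signed_unit(pixel: float) -> float:
    """Map a pixel value in [0, 1] to [-1, 1] (assumed image scaling)."""
    return pixel * 2.0 - 1.0


def mean_std_normalize(value: float, mean: float, std: float) -> float:
    """Standard mean/std normalization, matching the expected state values."""
    return (value - mean) / std


# Pixel values that appear in the test image.
image_values = [scale_to_signed_unit(p) for p in (0.0, 0.5, 1.0)]

# State [3.0, 7.0] with mean [1.0, 3.0] and std [2.0, 4.0], as in dataset_stats.
state = [
    mean_std_normalize(v, m, s)
    for v, m, s in zip((3.0, 7.0), (1.0, 3.0), (2.0, 4.0))
]
```

Applying the image mean 0.2 and std 0.1 from `dataset_stats` to a pixel of 0.0 would give -2.0 rather than the expected -1.0, which suggests the image stats are not used for this feature.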
def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
captured = []
class CapturingCore(FakeFastWAMCore):
def infer_action(self, **kwargs):
captured.append(
{
"image_shape": tuple(kwargs["input_image"].shape),
"proprio_shape": tuple(kwargs["proprio"].shape),
"prompt": kwargs["prompt"],
}
)
return {"action": torch.full((1, kwargs["action_horizon"], 3), float(len(captured)))}
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CapturingCore())
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(16, 16),
input_features={
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(2,)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
base_model_id=None,
)
policy = FastWAMPolicy(cfg)
loss, metrics = policy.forward(
{
"observation.images.image": torch.zeros(1, 3, 16, 16),
OBS_STATE: torch.zeros(1, 2),
ACTION: torch.zeros(1, 4, 3),
"context": torch.zeros(1, 5, 4096),
"context_mask": torch.ones(1, 5, dtype=torch.bool),
}
)
action = policy.predict_action_chunk(
{
"observation.images.image": torch.stack(
[
torch.zeros(3, 16, 16),
torch.ones(3, 16, 16),
]
),
OBS_STATE: torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"task": ["task 0", "task 1"],
}
)
assert loss.item() == 1.0
assert metrics["loss_action"] == 1.0
assert action.shape == (2, 4, 3)
assert action[:, 0, 0].tolist() == [1.0, 2.0]
assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)]
assert [item["proprio_shape"] for item in captured] == [(1, 2), (1, 2)]
assert [item["prompt"] for item in captured] == [
cfg.prompt_template.format(task="task 0"),
cfg.prompt_template.format(task="task 1"),
]
class CoreWithFrozenComponents(FakeFastWAMCore):
"""Fake core mirroring the real one: frozen VAE / text encoder held as
*unregistered* attributes (via `object.__setattr__`) so they are excluded from
`state_dict()` and the saved checkpoint, but still moved by the `_apply` override."""
def __init__(self):
super().__init__()
object.__setattr__(self, "vae", nn.Linear(2, 2))
object.__setattr__(self, "text_encoder", nn.Linear(2, 2))
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)
def _apply(self, fn, *args, **kwargs):
super()._apply(fn, *args, **kwargs)
self.vae._apply(fn)
self.text_encoder._apply(fn)
return self
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
def build_core(self, config):
core = CoreWithFrozenComponents()
with torch.no_grad():
core.dit.weight.fill_(0.5)
return core
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", build_core)
reference = FastWAMPolicy(cfg)
with torch.no_grad():
reference.model.dit.weight.fill_(1.25) # a distinctive, trained-looking weight
reference.save_pretrained(tmp_path)
# Building from Wan2.2 must never happen on a checkpoint load.
def fail_if_wan_pretrained_is_loaded(*args, **kwargs):
raise AssertionError("from_pretrained must not initialize or download the Wan2.2 backbone")
monkeypatch.setattr(
"lerobot.policies.fastwam.modular_fastwam.FastWAM.from_wan22_pretrained",
fail_if_wan_pretrained_is_loaded,
)
policy = FastWAMPolicy.from_pretrained(tmp_path)
assert isinstance(policy.model, CoreWithFrozenComponents)
# The bundled checkpoint weights overwrote the freshly built (0.5) DiT weights.
assert torch.allclose(policy.model.dit.weight, torch.full_like(policy.model.dit.weight, 1.25))
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
save_dir = tmp_path / "saved"
policy.save_pretrained(save_dir)
assert (save_dir / "model.safetensors").is_file()
# No Wan sidecar files either: the frozen backbone comes from the diffusers repo.
assert not (save_dir / "Wan2.2_VAE.safetensors").exists()
assert not (save_dir / "google").exists()
with safe_open(save_dir / "model.safetensors", framework="pt") as f:
keys = set(f.keys())
# Lean checkpoint: only the trainable DiT is saved; the frozen VAE / UMT5 text
# encoder are excluded (loaded from the diffusers/transformers repos at init).
assert any(key.startswith("model.dit.") for key in keys)
assert not any(key.startswith("model.vae.") for key in keys)
assert not any(key.startswith("model.text_encoder.") for key in keys)
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
cfg = FastWAMConfig(
action_dim=3,
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
# Unregistered: excluded from state_dict and from the optimizer's parameter set.
sd = policy.state_dict()
assert not any(k.startswith("model.vae.") or k.startswith("model.text_encoder.") for k in sd)
param_names = [n for n, _ in policy.named_parameters()]
assert not any("vae" in n or "text_encoder" in n for n in param_names)
# ...but the `_apply` override still carries them through `.to()` (dtype stands in
# for device on a CPU box), so they never strand off the rest of the model.
policy.to(torch.float64)
assert policy.model.dit.weight.dtype == torch.float64 # registered
assert policy.model.vae.weight.dtype == torch.float64 # unregistered, moved via _apply
assert policy.model.text_encoder.weight.dtype == torch.float64
def test_pretrained_config_round_trips_fastwam_features(tmp_path):
cfg = FastWAMConfig(action_dim=7, proprio_dim=8, image_size=(224, 448), base_model_id=None)
cfg.save_pretrained(tmp_path)
loaded = PreTrainedConfig.from_pretrained(tmp_path)
assert loaded.type == "fastwam"
assert loaded.image_features["observation.images.image"].type == FeatureType.VISUAL
assert loaded.action_feature.shape == (7,)
assert loaded.robot_state_feature.shape == (8,)
def test_vae_adapter_empty_build_encode_decode_shapes():
"""Offline glue check of the diffusers-backed VAE adapter (random weights).
Validates the encode/decode contract (48 latent channels, 16x spatial / 4x
temporal compression, list-or-batch input, scaling round-trip) without any
weight download. (Numerical fidelity vs the original Wan VAE is a separate,
GPU + real-weights verification step.)
"""
pytest.importorskip("diffusers")
from diffusers import AutoencoderKLWan
from lerobot.policies.fastwam.wan_adapters import WanVideoVAE38
# Production always loads a real pretrained VAE from the diffusers repo; here we
# build the same architecture with random weights and dummy standardization stats
# to exercise the adapter's shape/scaling contract offline (fidelity is checked
# separately, with real weights, on GPU).
arch = {
"base_dim": 160,
"decoder_base_dim": 256,
"z_dim": 48,
"dim_mult": [1, 2, 4, 4],
"num_res_blocks": 2,
"attn_scales": [],
"temporal_downsample": [False, True, True],
"dropout": 0.0,
"is_residual": True,
"in_channels": 12,
"out_channels": 12,
"patch_size": 2,
"scale_factor_spatial": 16,
"scale_factor_temporal": 4,
"clip_output": False,
"latents_mean": [0.0] * 48,
"latents_std": [1.0] * 48,
}
raw = AutoencoderKLWan.from_config(arch)
vae = WanVideoVAE38(dtype=torch.float32, device="cpu", pretrained=raw)
assert vae.z_dim == 48
assert vae.upsampling_factor == 16
assert vae.temporal_downsample_factor == 4
video = torch.rand(1, 3, 5, 32, 32) * 2 - 1 # [B,C,T,H,W] in [-1,1]
latents = vae.encode(video)
assert latents.shape == (1, 48, 2, 2, 2) # T'=(5-1)//4+1, H'=W'=32//16
decoded = vae.decode(latents)
assert decoded.shape[0] == 1 and decoded.shape[1] == 3 and decoded.shape[-2:] == (32, 32)
assert decoded.min() >= -1.0 and decoded.max() <= 1.0
# list input is accepted and equals the batched path
assert torch.equal(vae.encode([video[0]]), latents)
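Note: the latent shape asserted in the test above follows from the compression factors noted in its inline comment (`T'=(5-1)//4+1`, `H'=W'=32//16`). The helper below spells out that arithmetic. `wan_latent_shape` is a hypothetical name used only for illustration, and it assumes the height and width are multiples of the spatial factor.

```python
from typing import Tuple


def wan_latent_shape(
    batch: int,
    frames: int,
    height: int,
    width: int,
    z_dim: int = 48,
    spatial: int = 16,
    temporal: int = 4,
) -> Tuple[int, int, int, int, int]:
    """Latent shape [B, C, T', H', W'] for a video of shape [B, 3, T, H, W].

    The first frame is kept and the remaining frames are compressed in
    groups of `temporal`, hence (frames - 1) // temporal + 1.
    """
    latent_frames = (frames - 1) // temporal + 1
    return (batch, z_dim, latent_frames, height // spatial, width // spatial)


shape = wan_latent_shape(batch=1, frames=5, height=32, width=32)
```

The same formula gives one latent frame for a single input image (`frames=1`).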
@@ -1,2 +0,0 @@
# Local-only parity artifacts (regenerated via dump_original_n1_7.py); never committed.
*.npz
+11 -18
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes."""
"""Test script for LeRobot's Groot policy forward and inference passes."""
import gc
import os
@@ -41,20 +41,13 @@ pytestmark = pytest.mark.skipif(
)
# Define constants for dummy data (GR00T N1.7 native conventions).
# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images
# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch
# matches the model's native action space.
# Define constants for dummy data
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 40
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = auto_select_torch_device()
# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads
# via GrootPolicy.from_pretrained with root-level sharded safetensors.
MODEL_PATH = "nvidia/GR00T-N1.7-3B"
# Valid N1.7 embodiment tag carried by the checkpoint metadata.
EMBODIMENT_TAG = "gr1_unified"
MODEL_PATH = "aractingi/bimanual-handover-groot-10k"
def cleanup_memory():
@@ -95,13 +88,13 @@ def instantiate_lerobot_groot(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor."""
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
if from_pretrained:
policy = GrootPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
policy.config.embodiment_tag = EMBODIMENT_TAG
policy.config.embodiment_tag = "gr1"
else:
config = GrootConfig(
base_model_path=model_path,
@@ -109,7 +102,7 @@ def instantiate_lerobot_groot(
chunk_size=DUMMY_ACTION_HORIZON,
image_size=[IMAGE_SIZE, IMAGE_SIZE],
device=DEVICE,
embodiment_tag=EMBODIMENT_TAG,
embodiment_tag="gr1",
)
policy = GrootPolicy(config)
@@ -155,8 +148,8 @@ def create_dummy_data(device=DEVICE):
@require_cuda
def test_lerobot_groot_inference():
"""Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy."""
print("Test: LeRobot GR00T N1.7 Inference Pass")
"""Test the inference pass (select_action) of LeRobot's Groot policy."""
print("Test: LeRobot Groot Inference Pass")
set_seed_all(42)
@@ -188,9 +181,9 @@ def test_lerobot_groot_inference():
@require_cuda
def test_lerobot_groot_forward_pass():
"""Test the forward pass of LeRobot's GR00T N1.7 policy."""
"""Test the forward pass of LeRobot's Groot policy."""
print("\n" + "=" * 50)
print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)")
print("Test: LeRobot Groot Forward Pass (Training Mode)")
set_seed_all(42)
File diff suppressed because it is too large
+408 -171
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,194 +14,431 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
normalized flow-matching prediction before any action decoding) as NVIDIA's original
``gr00t`` package, given byte-identical pre-processed inputs and the same
flow-matching seed. The comparison is parametrized over every embodiment tag present
in the checkpoint.
To keep the comparison fair, the original outputs + the exact collated inputs are
produced once per embodiment in the original ``gr00t`` env via the companion script
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
to per-tag ``.npz`` files.
This test discovers those artifacts, replays the identical inputs through the LeRobot
model, and compares.
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
present, or when no artifact has been generated. By default it looks for artifacts in
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
for the full run procedure.
"""
"""Test script to verify Groot policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import gc
import os
from pathlib import Path
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
pytest.importorskip("gr00t")
pytest.importorskip("transformers")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="Requires a local GR00T N1.7 checkpoint + pre-generated artifacts; not for CI.",
reason="This test requires local Groot installation and is not meant for CI",
)
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
SEED = 42
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
from gr00t.data.dataset import ModalityConfig # noqa: E402
from gr00t.data.embodiment_tags import EmbodimentTag # noqa: E402
from gr00t.data.transform.base import ComposedModalityTransform # noqa: E402
from gr00t.model.policy import Gr00tPolicy # noqa: E402
# Artifact filenames are original_n1_7_<embodiment_tag>.npz
_ARTIFACT_PREFIX = "original_n1_7_"
_ARTIFACT_SUFFIX = ".npz"
# GR1 humanoid dimensions (from pretrained model metadata)
# The actual GR1 robot has 44 dimensions for both state and action
# GR00TTransform will pad state to 64 and truncate action to 32
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = "cpu"
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
GR1_BODY_PARTS = {
"left_arm": 7,
"left_hand": 6,
"left_leg": 6,
"neck": 3,
"right_arm": 7,
"right_hand": 6,
"right_leg": 6,
"waist": 3,
}
def _artifact_dir() -> Path:
    """Directory holding the per-embodiment .npz artifacts.

    Self-contained by default: a sibling ``artifacts/`` directory next to this test.
    Override with ``GROOT_N1_7_PARITY_DIR`` (e.g. to point at a scratch location).
    The directory is read-only here -- it is populated by ``utils/dump_original_n1_7.py``
    run in the original gr00t environment; the test never creates it.
    """
    env = os.environ.get("GROOT_N1_7_PARITY_DIR")
    if env:
        return Path(env)
    return Path(__file__).resolve().parent / "artifacts"


def _discover_artifacts() -> list[tuple[str, Path]]:
    """Return [(embodiment_tag, npz_path), ...] for every dumped artifact."""
    d = _artifact_dir()
    if not d.is_dir():
        return []
    out = []
    for p in sorted(d.glob(f"{_ARTIFACT_PREFIX}*{_ARTIFACT_SUFFIX}")):
        tag = p.name[len(_ARTIFACT_PREFIX) : -len(_ARTIFACT_SUFFIX)]
        out.append((tag, p))
    return out
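Review note: the discovery step above recovers the embodiment tag purely from the filename by slicing off a fixed prefix and suffix. A minimal standalone sketch of that idea follows, assuming the same `original_n1_7_<tag>.npz` naming; `discover` and `demo_tags` are illustrative names that are not part of the test file, and the files created here are empty placeholders rather than real artifacts.

```python
import tempfile
from pathlib import Path

PREFIX = "original_n1_7_"
SUFFIX = ".npz"


def discover(directory: Path) -> list[tuple[str, Path]]:
    """Return (tag, path) pairs for files named <PREFIX><tag><SUFFIX>, sorted by name."""
    if not directory.is_dir():
        # A missing directory is treated as "no artifacts", not as an error.
        return []
    found = []
    for path in sorted(directory.glob(f"{PREFIX}*{SUFFIX}")):
        # Slice the fixed prefix and suffix off the filename to get the tag.
        tag = path.name[len(PREFIX) : -len(SUFFIX)]
        found.append((tag, path))
    return found


def demo_tags() -> list[str]:
    """Create placeholder files in a temp dir and return the discovered tags."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for name in ("original_n1_7_libero_sim.npz", "original_n1_7_gr1.npz", "notes.txt"):
            (root / name).touch()
        return [tag for tag, _ in discover(root)]
```

Because the list is sorted by path, the parametrized test IDs stay stable from run to run, and unrelated files in the directory are ignored by the glob.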
def _resolve_checkpoint() -> str:
    env = os.environ.get("GROOT_N1_7_LIBERO_CKPT")
    if env:
        if not Path(env).exists():
            pytest.skip(f"GROOT_N1_7_LIBERO_CKPT={env} does not exist")
        return env
    try:
        from huggingface_hub import snapshot_download

        root = snapshot_download(
            "nvidia/GR00T-N1.7-LIBERO",
            local_files_only=True,
            allow_patterns=["libero_10/*"],
        )
    except Exception as exc:  # noqa: BLE001
        pytest.skip(f"GR00T N1.7 LIBERO checkpoint not available locally: {exc}")
    ckpt = Path(root) / "libero_10"
    if not (ckpt / "config.json").exists():
        pytest.skip(f"GR00T N1.7 LIBERO checkpoint incomplete at {ckpt}")
    return str(ckpt)
def _load_artifact(path: Path):
    data = np.load(path, allow_pickle=True)
    original_action = torch.from_numpy(data["action_pred"]).float()
    dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
    inputs = {}
    for key in data.files:
        if not key.startswith("in::"):
            continue
        name = key[4:]
        arr = data[key]
        t = torch.from_numpy(np.asarray(arr))
        declared = dtypes.get(key, "")
        if "int" in declared or "long" in declared:
            t = t.long()
        inputs[name] = t
    return original_action, inputs


def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
    """Rebuild the nested model-input dict from dot-prefixed flat keys."""
    nested: dict = {}
    for dotted, value in inputs.items():
        parts = dotted.split(".")
        cur = nested
        for p in parts[:-1]:
            cur = cur.setdefault(p, {})
        cur[parts[-1]] = value
    return nested.get("inputs", nested)
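Review note: the unflattening above is easier to check against a tiny example. The sketch below repeats the same dotted-key logic on plain Python values instead of tensors; the key names (`inputs.state`, `inputs.vlm.input_ids`) are made up for illustration and are not taken from a real artifact.

```python
def unflatten(flat: dict[str, object]) -> dict:
    """Rebuild a nested dict from dot-separated keys, unwrapping a top-level 'inputs'."""
    nested: dict = {}
    for dotted, value in flat.items():
        *parents, leaf = dotted.split(".")
        cur = nested
        for part in parents:
            # Create intermediate dicts on demand while walking down the path.
            cur = cur.setdefault(part, {})
        cur[leaf] = value
    # If everything lives under a single "inputs" wrapper, return its contents.
    return nested.get("inputs", nested)


# Hypothetical flat keys, as they might look after stripping an "in::" prefix.
wrapped = unflatten({"inputs.state": 1, "inputs.vlm.input_ids": 2})
unwrapped = unflatten({"state": 1, "vlm.input_ids": 2})
```

Both calls yield the same nested structure, which is why the helper works whether or not the producer wrapped the collated batch in an `inputs` entry.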
@pytest.fixture(scope="module")
def lerobot_model():
    """Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
    ckpt = _resolve_checkpoint()
    from lerobot.policies.groot.groot_n1_7 import GR00TN17

    model = GR00TN17.from_pretrained(
        ckpt,
        tune_llm=False,
        tune_visual=False,
        tune_projector=False,
        tune_diffusion_model=False,
        tune_vlln=False,
        transformers_loading_kwargs={"trust_remote_code": True},
    )
    # fp32 + SDPA on both sides: bf16 + differing attention kernels otherwise introduce
    # ~1e-2 numerical noise unrelated to the implementations.
    model.compute_dtype = "float32"
    model.config.compute_dtype = model.compute_dtype
    model.to(device=DEVICE, dtype=torch.float32)
    model.eval()
    return model
_ARTIFACTS = _discover_artifacts()
@pytest.mark.skipif(
not _ARTIFACTS,
reason=(
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
"env:\n .venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py "
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
),
)
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
original_action, flat_inputs = _load_artifact(artifact)
model_inputs = _unflatten(flat_inputs)
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
torch.manual_seed(SEED)
def cleanup_memory():
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
print("\nCleaning up memory...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
with torch.inference_mode():
out = lerobot_model.get_action(model_inputs)
lerobot_action = out["action_pred"].float().cpu()
torch.cuda.empty_cache()
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
print("Memory cleanup complete.")
t = min(original_action.shape[1], lerobot_action.shape[1])
d = min(original_action.shape[2], lerobot_action.shape[2])
original_action = original_action[:, :t, :d]
lerobot_action = lerobot_action[:, :t, :d]
diff = torch.abs(lerobot_action - original_action)
max_diff = diff.max().item()
print(
f"\n[{embodiment_tag}] shapes lerobot={tuple(lerobot_action.shape)} "
f"original={tuple(original_action.shape)} "
f"max|diff|={max_diff:.6e} mean|diff|={diff.mean().item():.6e}"
def set_seed_all(seed: int):
    """Set random seed for all RNG sources to ensure reproducibility."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Set deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)
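Review note: both test variants reseed immediately before sampling rather than once at the top. The sketch below illustrates why, using only the standard-library `random` module as a stand-in for the torch RNG; the function names and the "warm-up draws" are hypothetical and only model the fact that two implementations may consume a different amount of randomness before the sampler runs.

```python
import random


def sample_seed_late(seed: int, warmup_draws: int) -> list[float]:
    """Consume some unrelated randomness, then reseed right before sampling."""
    for _ in range(warmup_draws):
        random.random()  # stands in for setup work that touches the RNG
    random.seed(seed)  # reseed immediately before the "sampler"
    return [random.gauss(0.0, 1.0) for _ in range(4)]


def sample_seed_early(seed: int, warmup_draws: int) -> list[float]:
    """Seed once up front, so the warm-up draws shift the sampler's RNG state."""
    random.seed(seed)
    for _ in range(warmup_draws):
        random.random()
    return [random.gauss(0.0, 1.0) for _ in range(4)]
```

With late seeding, the sampled noise is independent of how many draws happened earlier, so two code paths with different setup costs still start the sampler from the same state.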
def instantiate_lerobot_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
) -> tuple[
GrootPolicy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
if from_pretrained:
policy = GrootPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
policy.config.embodiment_tag = "gr1"
else:
config = GrootConfig(
base_model_path=model_path,
n_action_steps=DUMMY_ACTION_HORIZON,
chunk_size=DUMMY_ACTION_HORIZON,
image_size=[IMAGE_SIZE, IMAGE_SIZE],
device=DEVICE,
embodiment_tag="gr1",
)
policy = GrootPolicy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_groot_pre_post_processors(
config=policy.config,
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
)
assert torch.allclose(lerobot_action, original_action, atol=ATOL, rtol=RTOL), (
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
return (policy, preprocessor, postprocessor)
def instantiate_original_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
):
"""Instantiate original Groot policy from NVIDIA's implementation."""
from gr00t.data.transform.concat import ConcatTransform
from gr00t.data.transform.state_action import StateActionToTensor
from gr00t.data.transform.video import VideoToNumpy, VideoToTensor
from gr00t.model.transforms import GR00TTransform
video_keys = ["video.ego_view"]
state_keys = [
"state"
] # Important: Use single concatenated "state" key (not split body parts) to match preprocessing
action_keys = [
"action.left_arm",
"action.right_arm",
"action.left_hand",
"action.right_hand",
"action.left_leg",
"action.right_leg",
"action.neck",
"action.waist",
]
language_keys = ["annotation.human.action.task_description"]
modality_config = {
"video": ModalityConfig(
delta_indices=[0], # Current frame only
modality_keys=video_keys,
),
"state": ModalityConfig(
delta_indices=[0],
modality_keys=state_keys,
),
"action": ModalityConfig(
delta_indices=list(range(DUMMY_ACTION_HORIZON)),
modality_keys=action_keys,
),
"language": ModalityConfig(
delta_indices=[0],
modality_keys=language_keys,
),
}
modality_transform = ComposedModalityTransform(
transforms=[
VideoToTensor(apply_to=video_keys),
VideoToNumpy(apply_to=video_keys), # Convert to numpy (GR00TTransform expects numpy arrays)
# State is already a single concatenated key, so no StateActionToTensor needed
# Convert action from numpy to tensor
StateActionToTensor(apply_to=action_keys),
# Concatenate only video and actions (state is already single key)
ConcatTransform(
video_concat_order=video_keys,
state_concat_order=[], # Empty: state is already a single key
action_concat_order=action_keys,
),
GR00TTransform(
max_state_dim=64,
max_action_dim=32,
state_horizon=1,
action_horizon=DUMMY_ACTION_HORIZON,
training=False,
),
]
)
policy = Gr00tPolicy(
model_path=model_path,
embodiment_tag=EmbodimentTag.GR1,
modality_config=modality_config,
modality_transform=modality_transform,
device=DEVICE,
)
return policy, modality_config, modality_transform
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 2
prompt = "Pick up the red cube and place it in the bin"
state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)
batch = {
"observation.state": state,
"action": torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=device, # Action ground truth (for training)
),
"observation.images.ego_view": torch.rand(
batch_size,
3,
IMAGE_SIZE,
IMAGE_SIZE,
dtype=torch.float32,
device=device, # Images in [0, 1] range as expected by LeRobot
),
"task": [prompt for _ in range(batch_size)],
}
return batch
def convert_lerobot_to_original_format(batch, modality_config):
"""Convert LeRobot batch format to original Groot format.
The original Groot expects observations in this format:
{
"video.<camera_name>": np.ndarray (T, H, W, C) or (B, T, H, W, C)
"state.<state_component>": np.ndarray (T, D) or (B, T, D)
"action.<action_component>": np.ndarray (T, D) or (B, T, D)
"annotation.<annotation_type>": str or list[str]
}
"""
# Original Groot expects (T, H, W, C) format for images
# LeRobot has (B, C, H, W) format, so we need to convert
observation = {}
for img_key in ["ego_view"]:
lerobot_key = f"observation.images.{img_key}"
if lerobot_key in batch:
img = batch[lerobot_key]
# Convert from (B, C, H, W) to (B, T=1, H, W, C)
img_np = img.permute(0, 2, 3, 1).unsqueeze(1).cpu().numpy()
# Convert [0, 1] to [0, 255] uint8 as expected by original
img_np = (img_np * 255).astype(np.uint8)
observation[f"video.{img_key}"] = img_np
# Important: The Original's GR00TTransform expects "state" as (B, T, D), not split body parts
if "observation.state" in batch:
state = batch["observation.state"]
state_np = state.unsqueeze(1).cpu().numpy() # (B, 1, D)
observation["state"] = state_np
if "action" in batch:
action = batch["action"]
action_np = action.cpu().numpy()
start_idx = 0
for part_name, part_dim in GR1_BODY_PARTS.items():
end_idx = start_idx + part_dim
observation[f"action.{part_name}"] = action_np[:, :, start_idx:end_idx]
start_idx = end_idx
if "task" in batch:
task_list = batch["task"]
# GR00TTransform expects language with (B, T) shape for batched data
# Create a (B, T=1) array where each element is the string directly
bsz = len(task_list)
task_array = np.empty((bsz, 1), dtype=object)
for i in range(bsz):
task_array[i, 0] = task_list[i] # Assign string directly to each (i, 0) position
observation["annotation.human.action.task_description"] = task_array
return observation
def test_groot_original_vs_lerobot_pretrained():
"""Test Groot original implementation vs LeRobot implementation with pretrained weights."""
print("Test: Groot Original vs LeRobot with Pretrained Weights (Inference)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
print("\n[LeRobot] Running inference...")
lerobot_policy.eval()
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
lerobot_actions = lerobot_policy.select_action(batch_lerobot_processed)
print("\n[Original] Running inference...")
original_policy.model.eval()
observation = convert_lerobot_to_original_format(batch, modality_config)
original_obs_transformed = modality_transform(deepcopy(observation))
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
original_model_output = original_policy.model.get_action(original_obs_transformed)
original_actions_raw = original_model_output["action_pred"] # [2, 16, 32]
# Take first timestep
original_actions = original_actions_raw[:, 0, :].to(lerobot_actions.device).to(lerobot_actions.dtype)
print("Action Comparison:")
diff = lerobot_actions - original_actions
abs_diff = torch.abs(diff)
for batch_idx in range(lerobot_actions.shape[0]):
print(f"\n{'=' * 60}")
print(f"Batch {batch_idx}")
print(f"{'=' * 60}")
print(f"{'Idx':<5} {'LeRobot':<14} {'Original':<14} {'Difference':<14}")
print("-" * 60)
for action_idx in range(lerobot_actions.shape[1]):
lr_val = lerobot_actions[batch_idx, action_idx].item()
orig_val = original_actions[batch_idx, action_idx].item()
diff_val = abs(lr_val - orig_val)
sign = "+" if (lr_val - orig_val) > 0 else "-"
print(f"{action_idx:<5} {lr_val:>13.6f} {orig_val:>13.6f} {sign}{diff_val:>12.6f}")
max_diff = abs_diff.max().item()
tolerance = 0.001
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6f}"
)
print(f"\nSuccess: Actions match within tolerance ({tolerance})!")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation
cleanup_memory()
def test_groot_forward_pass_comparison():
"""Test forward pass comparison between LeRobot and Original Groot implementations."""
print("Test: Forward Pass Comparison (Training Mode)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
lerobot_policy.eval()
original_policy.model.eval()
print("\n[LeRobot] Running forward pass...")
batch_lerobot = deepcopy(batch)
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
set_seed_all(42)
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
print(f" Loss: {lerobot_loss.item():.6f}")
print("\n[Original] Running forward pass...")
observation = convert_lerobot_to_original_format(batch, modality_config)
transformed_obs = modality_transform(observation)
if "action" not in transformed_obs:
action_for_forward = batch_lerobot_processed["action"]
action_mask_for_forward = batch_lerobot_processed["action_mask"]
# Match action horizon if needed
if action_for_forward.shape[1] != original_policy.model.action_horizon:
if action_for_forward.shape[1] < original_policy.model.action_horizon:
pad_size = original_policy.model.action_horizon - action_for_forward.shape[1]
last_action = action_for_forward[:, -1:, :]
padding = last_action.repeat(1, pad_size, 1)
action_for_forward = torch.cat([action_for_forward, padding], dim=1)
mask_padding = torch.zeros(
action_mask_for_forward.shape[0],
pad_size,
action_mask_for_forward.shape[2],
dtype=action_mask_for_forward.dtype,
device=action_mask_for_forward.device,
)
action_mask_for_forward = torch.cat([action_mask_for_forward, mask_padding], dim=1)
else:
action_for_forward = action_for_forward[:, : original_policy.model.action_horizon, :]
action_mask_for_forward = action_mask_for_forward[
:, : original_policy.model.action_horizon, :
]
transformed_obs["action"] = action_for_forward
transformed_obs["action_mask"] = action_mask_for_forward
set_seed_all(42)
with torch.no_grad():
original_outputs = original_policy.model.forward(transformed_obs)
original_loss = original_outputs["loss"]
print(f" Loss: {original_loss.item():.6f}")
loss_diff = abs(lerobot_loss.item() - original_loss.item())
loss_rel_diff = loss_diff / (abs(original_loss.item()) + 1e-8) * 100
print("\nLoss Values:")
print(f" LeRobot: {lerobot_loss.item():.6f}")
print(f" Original: {original_loss.item():.6f}")
print(f" Absolute difference: {loss_diff:.6f}")
print(f" Relative difference: {loss_rel_diff:.2f}%")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation, transformed_obs
cleanup_memory()
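Review note: the two comparisons in this file use different tolerance settings. One passes only `atol`, the other passes both `atol` and `rtol`. `torch.allclose(input, other)` checks `|input - other| <= atol + rtol * |other|` element-wise, so the relative term widens the bound for larger reference values. The sketch below reproduces that rule and the crop-to-common-shape step in plain Python on nested lists; the numbers are invented for illustration and are not model outputs.

```python
def crop_common(a: list[list[float]], b: list[list[float]]):
    """Crop two (T, D) nested lists to their common leading shape."""
    t = min(len(a), len(b))
    d = min(len(a[0]), len(b[0]))
    return [row[:d] for row in a[:t]], [row[:d] for row in b[:t]]


def allclose(actual, expected, atol: float, rtol: float) -> bool:
    """Element-wise |actual - expected| <= atol + rtol * |expected| over (T, D) lists."""
    return all(
        abs(x - y) <= atol + rtol * abs(y)
        for row_a, row_e in zip(actual, expected)
        for x, y in zip(row_a, row_e)
    )


# Invented values: the candidate has one extra trailing dimension that gets cropped.
reference = [[0.5, -1.0, 2.0]]
candidate = [[0.5004, -1.0015, 2.0, 9.9]]
candidate_c, reference_c = crop_common(candidate, reference)

with_rtol = allclose(candidate_c, reference_c, atol=1e-3, rtol=1e-3)
atol_only = allclose(candidate_c, reference_c, atol=1e-3, rtol=0.0)
```

In this example the second element differs by about 1.5e-3: it passes when the relative term is included and fails with the absolute bound alone, so the two tests are not equally strict even though both quote a tolerance of 1e-3.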
-1
@@ -1 +0,0 @@
"""Utilities shared by GR00T policy tests."""
@@ -1,198 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
"""Producer (run in the ORIGINAL gr00t env): dump original GR00T N1.7 outputs + inputs.
The original NVIDIA ``gr00t`` package pins ``transformers==4.57.3`` (py3.10) and its
model-config dataclasses are incompatible with the ``transformers==5.x`` that the
LeRobot GR00T N1.7 integration requires. The two implementations therefore cannot be
imported in the same Python process. To keep the parity comparison FAIR, we run the
original model in its native env here and serialize, PER EMBODIMENT TAG:
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
byte-identical tensors -- same image preprocessing, tokenization, normalization),
* the random seed used right before the flow-matching sampler,
* the raw ``action_pred`` tensor returned by ``model.get_action`` (normalized space,
before any per-implementation action decoding).
Inputs are built GENERICALLY from the checkpoint metadata (no per-tag hardcoding):
state keys + dims come from ``statistics.json``; video + language keys come from the
processor's per-embodiment modality configs. This lets us test many embodiment tags
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
``libero_sim``.
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
Usage:
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
--ckpt <path-to-GR00T-N1.7-LIBERO/libero_10> \
--out-dir tests/policies/groot/artifacts \
[--tags libero_sim,oxe_droid_relative_eef_relative_joint,...] \
[--device cuda] [--seed 42]
If --tags is omitted, every embodiment present in the checkpoint statistics is dumped.
"""
import argparse
import json
import os
from pathlib import Path
import numpy as np
import torch
IMAGE_SIZE = 256
BATCH_SIZE = 2
PROMPT = "pick up the black bowl and place it on the plate"
def load_statistics(ckpt: str) -> dict:
    with open(os.path.join(ckpt, "statistics.json")) as f:
        return json.load(f)


def make_observation(seed: int, video_keys, lang_key, state_spec):
    """Build a dummy observation dict generically from the embodiment metadata."""
    rng = np.random.default_rng(seed)
    video = {
        k: rng.integers(0, 256, (BATCH_SIZE, 1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
        for k in video_keys
    }
    # One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
    # Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
    # present-but-empty so the processor's state transform finds every expected key.
    state = {
        k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
        for k, dim in state_spec
    }
    language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
    return {"video": video, "state": state, "language": language}
def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_path):
    from gr00t.data.types import MessageType

    video_keys = modality_cfg["video"].modality_keys
    lang_key = modality_cfg["language"].modality_keys[0]
    observation = make_observation(args.seed, video_keys, lang_key, state_spec)
    # Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
    policy.embodiment_tag = type(policy.embodiment_tag)(tag)
    policy.modality_configs = {
        k: v for k, v in policy.processor.get_modality_configs()[tag].items() if k != "rl_info"
    }
    policy.language_key = policy.modality_configs["language"].modality_keys[0]
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    unbatched = policy._unbatch_observation(observation)
    processed = []
    for obs in unbatched:
        vla = policy._to_vla_step_data(obs)
        processed.append(policy.processor([{"type": MessageType.EPISODE_STEP.value, "content": vla}]))
    collated = policy.collate_fn(processed)

    def to_dev(x):
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(args.device, torch.float32)
        if isinstance(x, torch.Tensor):
            return x.to(args.device)
        if isinstance(x, dict):
            return {k: to_dev(v) for k, v in x.items()}
        return x

    collated = {k: to_dev(v) for k, v in collated.items()}
    torch.manual_seed(args.seed)
    with torch.inference_mode():
        out = fair_model.get_action(**collated)
    action_pred = out["action_pred"].float().cpu().numpy()
    flat, meta = {}, {}

    def flatten(prefix, obj):
        if isinstance(obj, torch.Tensor):
            arr = obj.float().cpu().numpy() if torch.is_floating_point(obj) else obj.cpu().numpy()
            flat[f"in::{prefix}"] = arr
            meta[f"in::{prefix}"] = str(obj.dtype)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                flatten(f"{prefix}.{k}" if prefix else k, v)
        elif isinstance(obj, (list, tuple)):
            flat[f"in::{prefix}"] = np.array(obj, dtype=object)
        else:
            flat[f"in::{prefix}"] = np.array(obj)

    flatten("", collated)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(
        out_path,
        action_pred=action_pred,
        seed=np.array(args.seed),
        device=np.array(args.device),
        embodiment_tag=np.array(tag),
        meta_keys=np.array(list(meta.keys()), dtype=object),
        meta_dtypes=np.array(list(meta.values()), dtype=object),
        **flat,
    )
    print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--out-dir", required=True, help="directory for per-tag .npz files")
    ap.add_argument("--tags", default="", help="comma-separated embodiment tags (default: all in stats)")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    from gr00t.policy.gr00t_policy import Gr00tPolicy
    from transformers import AutoConfig, AutoModel

    stats = load_statistics(args.ckpt)
    requested = [t.strip() for t in args.tags.split(",") if t.strip()] or list(stats.keys())
    # Load the policy once (for its processor/preprocessing) on any valid tag.
    bootstrap_tag = "libero_sim" if "libero_sim" in stats else requested[0]
    policy = Gr00tPolicy(embodiment_tag=bootstrap_tag, model_path=args.ckpt, device=args.device)
    all_modality = policy.processor.get_modality_configs()
    # Load a FAIR model (SDPA + fp32) once and reuse across tags. Otherwise the
    # original checkpoint default (flash_attention_2 + bf16) introduces kernel/rounding
    # noise vs the LeRobot env (which has no flash_attn and runs SDPA).
    cfg = AutoConfig.from_pretrained(args.ckpt, trust_remote_code=True)
    cfg.use_flash_attention = False
    cfg.load_bf16 = False
    fair_model = AutoModel.from_pretrained(args.ckpt, config=cfg, trust_remote_code=True)
    fair_model.to(device=args.device, dtype=torch.float32)
    fair_model.eval()
    out_dir = Path(args.out_dir)
    done, skipped = [], []
    for tag in requested:
        if tag not in stats or tag not in all_modality:
            print(f"[skip] {tag}: not present in checkpoint statistics/modality configs")
            skipped.append(tag)
            continue
        state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
        try:
            dump_one_tag(
                policy, fair_model, tag, all_modality[tag], state_spec, args,
                out_dir / f"original_n1_7_{tag}.npz",
            )
            done.append(tag)
        except Exception as exc:  # noqa: BLE001
            print(f"[fail] {tag}: {type(exc).__name__}: {exc}")
            skipped.append(tag)
    print(f"\nDumped {len(done)} tags: {done}")
    if skipped:
        print(f"Skipped/failed {len(skipped)} tags: {skipped}")


if __name__ == "__main__":
    main()
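Review note: two small pieces of `main()` decide what gets dumped, namely the `--tags` parsing with its fall-back to every tag in the statistics, and the per-key state dimensions read from the length of each `min` vector. The sketch below isolates both on a hand-written dictionary. `HYPOTHETICAL_STATS` and its key names are invented for illustration and only mimic the layout that the script's indexing implies (`stats[tag]["state"][key]["min"]`); a real `statistics.json` may carry more fields.

```python
# Invented stand-in for a checkpoint's statistics.json (layout assumed from the script).
HYPOTHETICAL_STATS = {
    "libero_sim": {
        "state": {
            "joint": {"min": [0.0, 0.0, 0.0]},
            "gripper": {"min": [0.0]},
            "eef": {"min": []},  # a zero-dimensional key stays present but empty
        }
    },
    "gr1": {"state": {"arm": {"min": [0.0, 0.0]}}},
}


def select_tags(tags_arg: str, stats: dict) -> list[str]:
    """Parse a comma-separated tag list; an empty selection means 'all tags in stats'."""
    return [t.strip() for t in tags_arg.split(",") if t.strip()] or list(stats.keys())


def state_spec(stats: dict, tag: str) -> list[tuple[str, int]]:
    """Derive (state_key, dimension) pairs from the length of each 'min' vector."""
    return [(key, len(entry["min"])) for key, entry in stats[tag]["state"].items()]
```

Stray whitespace and empty items in `--tags` are dropped, and a key with an empty `min` vector yields dimension 0 rather than being skipped, which matches the "present-but-empty" comment in `make_observation`.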
+220
@@ -24,6 +24,7 @@ from typing import Any
import pytest
import torch
import torch.nn as nn
from safetensors.torch import load_file
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
return features
class MockLazyTensorStateStep(ProcessorStep):
    """Mock step whose tensor state is not present in constructor config."""

    def __init__(
        self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
    ):
        self.name = name
        self.scale = scale
        self.tensor_state: torch.Tensor | None = None
        if initial_value is not None:
            self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return the transition unchanged."""
        return transition

    def get_config(self) -> dict[str, Any]:
        """Return constructor config while intentionally omitting tensor state."""
        return {
            "name": self.name,
            "scale": self.scale,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return tensor state only after it has been initialized or loaded."""
        if self.tensor_state is None:
            return {}
        return {"tensor_state": self.tensor_state}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        """Load tensor state."""
        self.tensor_state = state["tensor_state"].clone()

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return features unchanged."""
        return features


@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
    """Registered lazy tensor state step for registry-based serialization tests."""
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
assert torch.allclose(loaded_step.running_mean, step.running_mean)
def test_get_config_matches_saved_json():
    """Test that in-memory config matches the config written by save_pretrained."""
    stateless_step = MockStep(name="stateless")
    stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
    pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
    in_memory_config = pipeline.get_config()
    assert pipeline.get_config() == in_memory_config
    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        config_path = Path(tmp_dir) / "memory_pipeline.json"
        with open(config_path) as file_pointer:
            saved_config = json.load(file_pointer)
    assert in_memory_config == saved_config
    assert "state_file" not in in_memory_config["steps"][0]
    assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"


def test_state_dict_matches_saved_safetensors():
    """Test that in-memory state matches the safetensors written by save_pretrained."""
    stateful_step = MockLazyTensorStateStep(initial_value=7.0)
    pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
    in_memory_state_dict = pipeline.state_dict()
    state_filename = "stateful_pipeline_step_0.safetensors"
    state_key = "stateful_pipeline_step_0"
    assert set(in_memory_state_dict) == {state_key}
    assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
    in_memory_state_dict[state_key]["tensor_state"].add_(1)
    assert stateful_step.tensor_state is not None
    assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        saved_state_dict = load_file(Path(tmp_dir) / state_filename)
        torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))


def test_save_pretrained_still_writes_expected_serialization_files():
    """Test that save_pretrained keeps the existing config and state filenames."""
    stateful_step = MockLazyTensorStateStep(initial_value=3.0)
    pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        save_path = Path(tmp_dir)
        assert (save_path / "policy_preprocessor.json").exists()
        assert (save_path / "policy_preprocessor_step_0.safetensors").exists()


def test_from_config_round_trips_stateful_pipeline():
    """Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
    stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
    pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
    config = pipeline.get_config()
    pipeline_state_dict = pipeline.state_dict()
    loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
    loaded_step = loaded_pipeline.steps[0]
    assert len(loaded_pipeline) == 1
    assert isinstance(loaded_step, MockLazyTensorStateStep)
    torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))


def test_from_config_round_trips_registered_stateful_pipeline():
    """Test that from_config resolves registry steps and loads their named tensor state."""
    stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
    pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
    config = pipeline.get_config()
    pipeline_state_dict = pipeline.state_dict()
    state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
    state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
    assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
    assert config["steps"][0]["state_file"] == state_filename
    assert set(pipeline_state_dict) == {state_key}
    loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
    loaded_step = loaded_pipeline.steps[0]
    assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
    assert loaded_step.tensor_state is not None
    torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
def test_from_config_preserves_state_metadata_for_empty_initial_state():
"""Test in-memory loading when rebuilt steps start without tensor state."""
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.state_dict() == {}
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
loaded_pipeline.load_state_dict(pipeline_state_dict)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
def test_from_config_applies_overrides_before_state_loading():
    """Test that constructor overrides and tensor state loading are separate operations."""
    stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
    pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
    config = pipeline.get_config()
    pipeline_state_dict = pipeline.state_dict()
    loaded_pipeline = DataProcessorPipeline.from_config(
        config,
        state_dict=pipeline_state_dict,
        overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
    )
    loaded_step = loaded_pipeline.steps[0]
    assert isinstance(loaded_step, MockLazyTensorStateStep)
    assert loaded_step.scale == 5.0
    torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
def test_load_state_dict_raises_on_missing_expected_state():
    """Test loading raises when serialized config expects missing state."""
    stateful_step = MockLazyTensorStateStep(initial_value=19.0)
    pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
    loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
    with pytest.raises(KeyError, match="missing_pipeline_step_0"):
        loaded_pipeline.load_state_dict({})
def test_load_state_dict_raises_on_unexpected_extra_state():
    """Test loading raises on unexpected top-level state keys."""
    pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
    with pytest.raises(KeyError, match="extra"):
        pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
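Review note: the two tests above describe strict key checking, where a missing expected key and an unexpected extra key both raise `KeyError`. A minimal sketch of that behaviour, assuming the pipeline knows the set of keys it expects, might look like this. The function name `check_state_keys` and the message wording are hypothetical and are not the pipeline's actual code.

```python
def check_state_keys(expected_keys: set[str], state_dict: dict[str, object]) -> None:
    """Hypothetical strict check: all expected keys present, no extra keys."""
    provided_keys = set(state_dict)
    missing = sorted(expected_keys - provided_keys)
    if missing:
        # Mirrors the "missing expected state" case in the tests above.
        raise KeyError(f"Missing state for keys: {missing}")
    unexpected = sorted(provided_keys - expected_keys)
    if unexpected:
        # Mirrors the "unexpected extra state" case in the tests above.
        raise KeyError(f"Unexpected state keys: {unexpected}")


# A matching state dict passes silently.
check_state_keys({"missing_pipeline_step_0"}, {"missing_pipeline_step_0": {}})
```

Raising in both directions keeps a stale or mismatched state dict from loading silently.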
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
    """Test stateless in-memory serialization and loading."""
    pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
    config = pipeline.get_config()
    config_without_name = {"steps": config["steps"]}
    assert pipeline.state_dict() == {}
    assert all("state_file" not in step_entry for step_entry in config["steps"])
    loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
    assert loaded_pipeline.name == "DataProcessorPipeline"
    assert loaded_pipeline.state_dict() == {}
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
def test_from_config_rejects_non_dict_config(invalid_config):
    """Test from_config reports invalid top-level config values cleanly."""
    with pytest.raises(ValueError, match="not a valid processor configuration"):
        DataProcessorPipeline.from_config(invalid_config)  # type: ignore[arg-type]
class MockModuleStep(ProcessorStep, nn.Module):
    """Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
+5
@@ -59,6 +59,7 @@ def test_strategy_config_types():
from lerobot.rollout import (
BaseStrategyConfig,
DAggerStrategyConfig,
EpisodicStrategyConfig,
HighlightStrategyConfig,
SentryStrategyConfig,
)
@@ -67,6 +68,7 @@ def test_strategy_config_types():
assert SentryStrategyConfig().type == "sentry"
assert HighlightStrategyConfig().type == "highlight"
assert DAggerStrategyConfig().type == "dagger"
assert EpisodicStrategyConfig().type == "episodic"
def test_dagger_config_invalid_input_device():
@@ -203,6 +205,8 @@ def test_create_strategy_dispatches():
BaseStrategyConfig,
DAggerStrategy,
DAggerStrategyConfig,
EpisodicStrategy,
EpisodicStrategyConfig,
SentryStrategy,
SentryStrategyConfig,
create_strategy,
@@ -211,6 +215,7 @@ def test_create_strategy_dispatches():
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
assert isinstance(create_strategy(EpisodicStrategyConfig()), EpisodicStrategy)
def test_create_strategy_unknown_raises():
Generated
+49 -5
@@ -1464,6 +1464,17 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" },
]
[[package]]
name = "flash-attn"
version = "2.8.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "einops", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" },
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3b/b2/8d76c41ad7974ee264754709c22963447f7f8134613fd9ce80984ed0dab7/flash_attn-2.8.3.tar.gz", hash = "sha256:1e71dd64a9e0280e0447b8a0c2541bad4bf6ac65bdeaa2f90e51a9e57de0370d", size = 8447812, upload-time = "2025-08-15T08:28:12.911Z" }
[[package]]
name = "flatbuffers"
version = "25.12.19"
@@ -2679,10 +2690,8 @@ all = [
{ name = "contourpy" },
{ name = "datasets" },
{ name = "debugpy" },
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
{ name = "deepdiff" },
{ name = "diffusers" },
{ name = "dm-tree" },
{ name = "dynamixel-sdk" },
{ name = "faker" },
{ name = "fastapi" },
@@ -2730,7 +2739,6 @@ all = [
{ name = "scikit-image" },
{ name = "scipy" },
{ name = "teleop" },
{ name = "timm" },
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
{ name = "torchdiffeq" },
{ name = "transformers" },
@@ -2822,6 +2830,10 @@ eo1 = [
evaluation = [
{ name = "av" },
]
fastwam = [
{ name = "diffusers" },
{ name = "transformers" },
]
feetech = [
{ name = "deepdiff" },
{ name = "feetech-servo-sdk" },
@@ -2835,6 +2847,8 @@ groot = [
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
{ name = "diffusers" },
{ name = "dm-tree" },
{ name = "flash-attn", marker = "sys_platform != 'darwin'" },
{ name = "ninja" },
{ name = "peft" },
{ name = "timm" },
{ name = "transformers" },
@@ -3077,6 +3091,7 @@ requires-dist = [
{ name = "faker", marker = "extra == 'sarm'", specifier = ">=33.0.0,<35.0.0" },
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
@@ -3112,16 +3127,17 @@ requires-dist = [
{ name = "lerobot", extras = ["deepdiff-dep"], marker = "extra == 'hardware'" },
{ name = "lerobot", extras = ["dev"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'fastwam'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["fastwam"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
{ name = "lerobot", extras = ["gamepad"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["groot"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'async'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'dev'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'hilserl'" },
@@ -3186,6 +3202,7 @@ requires-dist = [
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'fastwam'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
@@ -3214,6 +3231,7 @@ requires-dist = [
{ name = "motorbridge", marker = "extra == 'motorbridge-dep'", specifier = ">=0.3.2,<0.4.0" },
{ name = "motorbridge-smart-servo", marker = "extra == 'motorbridge-smart-servo-dep'", specifier = ">=0.0.4,<0.1.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" },
{ name = "ninja", marker = "extra == 'groot'", specifier = ">=1.11.1,<2.0.0" },
{ name = "num2words", marker = "extra == 'smolvla'", specifier = ">=0.5.14,<0.6.0" },
{ name = "numpy", specifier = ">=2.0.0,<2.3.0" },
{ name = "onnx", marker = "extra == 'unitree-g1'", specifier = ">=1.16.0,<2.0.0" },
@@ -3265,7 +3283,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "fastwam", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"
@@ -4001,6 +4019,32 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
]
[[package]]
name = "ninja"
version = "1.13.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558, upload-time = "2025-08-11T15:10:19.421Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3c/74/d02409ed2aa865e051b7edda22ad416a39d81a84980f544f8de717cab133/ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1", size = 310125, upload-time = "2025-08-11T15:09:50.971Z" },
{ url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467, upload-time = "2025-08-11T15:09:52.767Z" },
{ url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834, upload-time = "2025-08-11T15:09:54.115Z" },
{ url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736, upload-time = "2025-08-11T15:09:55.745Z" },
{ url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034, upload-time = "2025-08-11T15:09:57.394Z" },
{ url = "https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716, upload-time = "2025-08-11T15:09:58.696Z" },
{ url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843, upload-time = "2025-08-11T15:10:00.046Z" },
{ url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402, upload-time = "2025-08-11T15:10:01.657Z" },
{ url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388, upload-time = "2025-08-11T15:10:03.349Z" },
{ url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501, upload-time = "2025-08-11T15:10:04.735Z" },
{ url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280, upload-time = "2025-08-11T15:10:06.512Z" },
{ url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420, upload-time = "2025-08-11T15:10:08.35Z" },
{ url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106, upload-time = "2025-08-11T15:10:09.818Z" },
{ url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138, upload-time = "2025-08-11T15:10:11.366Z" },
{ url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758, upload-time = "2025-08-11T15:10:13.295Z" },
{ url = "https://files.pythonhosted.org/packages/95/97/51359c77527d45943fe7a94d00a3843b81162e6c4244b3579fe8fc54cb9c/ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9", size = 267201, upload-time = "2025-08-11T15:10:15.158Z" },
{ url = "https://files.pythonhosted.org/packages/29/45/c0adfbfb0b5895aa18cec400c535b4f7ff3e52536e0403602fc1a23f7de9/ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e", size = 309975, upload-time = "2025-08-11T15:10:16.697Z" },
{ url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" },
]
[[package]]
name = "nodeenv"
version = "1.10.0"