Compare commits

..

1 Commits

Author SHA1 Message Date
Khalil Meftah 17e217d175 refactor(policies): clean MolmoAct2 to follow EO1/TOPReward patterns
Align the MolmoAct2 implementation with lerobot codebase conventions:

- Rename hf_model/ to molmoact2_hf_model/
- Slim config: move all I/O and runtime logic to modeling
- Remove blanket  from 8 vendored files, fix 66 lint issues
- Deduplicate _hf_token() and _resolve_checkpoint_location()
- Make huggingface_hub imports lazy
- Remove custom MolmoAct2CosineDecayWithWarmupSchedulerConfig, use base class
- Extract 13 static/classmethods from MolmoAct2Policy to free functions
- Replace print() with logger in vendored action_tokenizer
- Add module docstrings, class docstring, and key method docstrings
- Add module-level loggers to modeling and processor
- Fix docs: pip to uv install, deduplicate README symlink
- Remove shebangs from all files
2026-06-05 16:31:03 +02:00
41 changed files with 4029 additions and 7239 deletions
+1 -1
View File
@@ -105,7 +105,7 @@ lerobot-train \
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
+1 -1
View File
@@ -68,7 +68,7 @@
- local: eo1
title: EO-1
- local: groot
title: NVIDIA GR00T
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
- local: multi_task_dit
+1 -1
View File
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
- [SmolVLA](./smolvla)
- [Pi0.5](./pi05)
- [GR00T N1.7](./groot)
- [GR00T N1.5](./groot)
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
+30 -76
View File
@@ -1,16 +1,16 @@
# GR00T Policy
# GR00T N1.5 Policy
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
LeRobot integrates GR00T N1.7 through the `groot` policy type.
This document outlines the specifics of its integration and usage within the LeRobot framework.
## Model Overview
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
@@ -28,46 +28,33 @@ This approach allows the model to be highly adaptable through post-training for
## Installation Requirements
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
As of today, GR00T N1.5 requires flash attention for it's internal working.
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
```bash
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
--index-url https://download.pytorch.org/whl/cu128
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
```
3. Install and verify Flash Attention:
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
```bash
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
4. Install LeRobot with the GR00T extra:
3. Install LeRobot by running:
```bash
pip install "lerobot[groot]"
pip install lerobot[groot]
```
For a source checkout, use the same order, then install the local package with:
```bash
pip install -e ".[groot]"
```
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
## Usage
To use GR00T N1.7:
To use GR00T in your LeRobot configuration, specify the policy type as:
```bash
--policy.type=groot \
--policy.model_version=n1.7
```python
policy.type=groot
```
## Training
@@ -100,54 +87,21 @@ accelerate launch \
## Performance Results
### LIBERO Benchmark Results
### Libero Benchmark Results
> [!NOTE]
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
> Follow our instructions for Libero usage: [Libero](./libero)
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
### GR00T N1.7 LIBERO Checkpoints
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82.0% | 92.0% |
| **Libero Object** | 99.0% | 92.0% |
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
| Suite | Checkpoint subdirectory |
| -------------- | ----------------------- |
| LIBERO Spatial | `libero_spatial` |
| LIBERO Object | `libero_object` |
| LIBERO Goal | `libero_goal` |
| LIBERO 10 | `libero_10` |
Preliminary LeRobot integration results:
| Suite | Status | Success rate | n_episodes |
| -------------- | ------ | -----------: | ---------: |
| LIBERO Spatial | ✓ | ~95% | XX |
| LIBERO Object | ✓ | XX% | XX |
| LIBERO Goal | ✓ | XX% | XX |
| LIBERO 10 | ✓ | XX% | XX |
| **Average** | ✓ | **XX%** | **XX** |
Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
lerobot-eval \
--policy.type=groot \
--policy.model_version=n1.7 \
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
--policy.embodiment_tag=libero_sim \
--env.type=libero \
--env.task=libero_spatial \
--eval.n_episodes=50
```
Use `eval.n_episodes >= 50` per suite when reporting success rates.
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
### Evaluate in your hardware setup
@@ -177,4 +131,4 @@ lerobot-rollout\
## License
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
+3 -3
View File
@@ -17,7 +17,7 @@ the paper, see [allenai/molmoact2](https://github.com/allenai/molmoact2).
Install LeRobot with the MolmoAct2 optional dependencies:
```bash
pip install -e ".[molmoact2]"
uv sync --locked --extra molmoact2
```
To run the models in this repository, you need an NVIDIA GPU. The measurements
@@ -46,8 +46,8 @@ The repo has been tested with Ubuntu 22.04.
To use MolmoAct2 in a LeRobot training config, set:
```python
policy.type=molmoact2
```bash
--policy.type=molmoact2
```
## Training
+1 -78
View File
@@ -24,81 +24,4 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Models:
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
## Original-vs-LeRobot parity test
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
produces the **same raw model output** (`get_action(...)["action_pred"]`, the
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
byte-identical pre-processed inputs and the same flow-matching seed. It is
parametrized over every embodiment tag present in the checkpoint.
### Why two environments
The original `gr00t` package pins `transformers==4.57.3` (Python 3.10); this
integration requires `transformers>=5.x` (Qwen3-VL). Under 5.x, `PretrainedConfig`
is itself a defaulted dataclass, so the original config dataclasses fail to import
(`non-default argument follows default argument`). The two implementations therefore
**cannot be imported in the same Python process**.
So the test uses a **producer / consumer** split across two venvs:
1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
gr00t venv. For each embodiment it builds dummy inputs generically from the
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
the processor modality configs), runs the original model, and saves the exact
collated inputs + raw `action_pred` to one `.npz` per tag.
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
`.npz`, replays the byte-identical inputs through the LeRobot model with the same
seed, and asserts the outputs match.
### Fairness controls
- **Same pre-processed inputs** — the original processor's `input_ids`,
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
kernel/rounding noise, not an implementation difference.)
- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
### How to run
```bash
# Resolve a local checkpoint (GR00T-N1.7-LIBERO / libero_10)
CKPT=$(python - <<'PY'
import os
from huggingface_hub import snapshot_download
print(os.path.join(snapshot_download("nvidia/GR00T-N1.7-LIBERO",
allow_patterns=["libero_10/*"]), "libero_10"))
PY
)
# 1) Produce the original-side artifacts for all embodiments (original gr00t venv, CUDA)
CUDA_VISIBLE_DEVICES=0 /path/to/Isaac-GR00T/.venv-original/bin/python \
tests/policies/groot/utils/dump_original_n1_7.py \
--ckpt "$CKPT" --out-dir tests/policies/groot/artifacts --device cuda --seed 42
# 2) Run the parity test (LeRobot venv) — one parametrized case per embodiment
CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
```
The `.npz` artifacts are local-only (gitignored, ~69 MB each) and are regenerated by
the producer; they are never committed. The test **skips** (does not fail) on CI or
when the checkpoint / artifacts are absent.
#### Env knobs (all optional)
| Var | Default | Purpose |
|---|---|---|
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
+18 -14
View File
@@ -280,22 +280,26 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
# GROOT handles normalization in groot_pack_inputs_v3 step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
return make_groot_pre_post_processors_from_pretrained(
config=policy_cfg,
pretrained_path=pretrained_path,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
preprocessor_config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
postprocessor_config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
)
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
+1 -9
View File
@@ -18,12 +18,4 @@ from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
def __getattr__(name: str):
if name in {"GR00TN17", "GR00TN17Config"}:
from .groot_n1_7 import GR00TN17, GR00TN17Config
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
def swish(x):
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
"""
Produces a sinusoidal encoding of shape (B, T, w)
given timesteps of shape (B, T).
"""
def __init__(self, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps):
# timesteps: shape (B, T)
# We'll compute sin/cos frequencies across dim T
timesteps = timesteps.float() # ensure float
b, t = timesteps.shape
device = timesteps.device
half_dim = self.embedding_dim // 2
# typical log space frequencies for sinusoidal encoding
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
# Expand timesteps to (B, T, 1) then multiply
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
@@ -181,7 +181,8 @@ class BasicTransformerBlock(nn.Module):
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
@@ -317,71 +318,6 @@ class DiT(ModelMixin, ConfigMixin):
return self.proj_out_2(hidden_states)
class AlternateVLDiT(DiT):
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
super().__init__(*args, **kwargs)
self.attend_text_every_n_blocks = attend_text_every_n_blocks
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
return_all_hidden_states: bool = False,
image_mask: torch.Tensor | None = None,
backbone_attention_mask: torch.Tensor | None = None,
):
if image_mask is None:
raise ValueError("image_mask is required for AlternateVLDiT.")
if backbone_attention_mask is None:
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
image_attention_mask = image_mask & backbone_attention_mask
non_image_attention_mask = (~image_mask) & backbone_attention_mask
all_hidden_states = [hidden_states]
if not self.config.interleave_self_attention:
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
curr_encoder_attention_mask = (
non_image_attention_mask
if idx % (2 * self.attend_text_every_n_blocks) == 0
else image_attention_mask
)
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=curr_encoder_attention_mask,
temb=temb,
)
all_hidden_states.append(hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
return self.proj_out_2(hidden_states)
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@@ -0,0 +1,408 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import field
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
else:
PretrainedConfig = object
BatchFeature = None
from .action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
from .cross_attention_dit import DiT, SelfAttentionTransformer
class CategorySpecificLinear(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim):
super().__init__()
self.num_categories = num_categories
# For each category, we have separate weights and biases.
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x, cat_ids):
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
super().__init__()
self.num_categories = num_categories
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x, cat_ids):
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(self, action_dim, hidden_size, num_embodiments):
super().__init__()
self.hidden_size = hidden_size
self.num_embodiments = num_embodiments
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions, timesteps, cat_ids):
"""
actions: shape (B, T, action_dim)
timesteps: shape (B,) -- a single scalar per batch item
cat_ids: shape (B,)
returns: shape (B, T, hidden_size)
"""
b, t, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# e.g. if timesteps is (B,), replicate across T
if timesteps.dim() == 1 and timesteps.shape[0] == b:
# shape (B,) => (B,T)
timesteps = timesteps.unsqueeze(1).expand(-1, t)
else:
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions, cat_ids)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = swish(self.W2(x, cat_ids))
# 5) Finally W3 => (B, T, w)
x = self.W3(x, cat_ids)
return x
class FlowmatchingActionHeadConfig(PretrainedConfig):
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
backbone_embedding_dim: int = field(
default=1536, metadata={"help": "Backbone embedding channel dimension."}
)
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
num_timestep_buckets: int = field(
default=1000, metadata={"help": "Number of timestep discretization buckets."}
)
num_inference_timesteps: int = field(
default=None,
metadata={"help": "Number of inference steps for noise diffusion."},
)
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
tune_diffusion_model: bool = field(
default=True, metadata={"help": "Whether to tune the diffusion model."}
)
load_pretrained_det_decode_layer_path: str = field(
default=None, metadata={"help": "Path to pretrained detection model."}
)
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
freeze_decode_layer: bool = field(default=False)
expand_batch: int = field(default=None)
use_vlln: bool = field(default=True)
vl_self_attention_cfg: dict = field(default=None)
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
class FlowmatchingActionHead(nn.Module):
config_class = FlowmatchingActionHeadConfig
supports_gradient_checkpointing = True
def __init__(
self,
config: FlowmatchingActionHeadConfig,
):
super().__init__()
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
self.model = DiT(**config.diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=config.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
self.vl_self_attention = (
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
)
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
for p in self.parameters():
p.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
print(f"Tune action head projector: {self.tune_projector}")
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_projector and not tune_diffusion_model:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Action head trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No action head trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
def sample_time(self, batch_size, device, dtype):
if self._beta_dist is None:
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = backbone_output["backbone_features"]
backbone_features = self.vlln(backbone_features)
backbone_features = self.vl_self_attention(backbone_features)
backbone_output["backbone_features"] = backbone_features
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
# Set frozen modules to eval
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
if self.config.expand_batch is not None:
for k, v in backbone_output.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
backbone_output[k] = expanded
for k, v in action_input.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
action_input[k] = expanded
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
device = vl_embs.device
# Get embodiment ID.
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Embed noised action trajectory.
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None] # shape (B,1,1) for broadcast
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
# Convert (continuous) t -> discrete if needed
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
vl_attn_mask = backbone_output.backbone_attention_mask
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=vl_attn_mask,
timestep=t_discretized,
return_all_hidden_states=False, # NOTE (YL): not using flare now
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
# Slice out only the action portion of pred and target.
action_mask = action_input.action_mask
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = loss.sum() / action_mask.sum()
output_dict = {
"loss": loss,
}
return BatchFeature(data=output_dict)
@torch.no_grad()
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Set initial actions as the sampled noise.
batch_size = vl_embs.shape[0]
device = vl_embs.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=vl_embs.dtype,
device=device,
)
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# Run denoising steps.
for t in range(num_steps):
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# Embed noised action trajectory.
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
# Run model forward.
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -self.action_horizon :]
# Update actions using euler integration.
actions = actions + dt * pred_velocity
return BatchFeature(data={"action_pred": actions})
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype
@@ -14,294 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5"
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7,
"n1_7": GROOT_N1_7,
"n1d7": GROOT_N1_7,
"n17": GROOT_N1_7,
"1.7": GROOT_N1_7,
}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
"none": None,
"": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
}
def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = GROOT_N1_7
raise ValueError(
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
)
return normalized
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
if transform is None:
return None
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
supported = ", ".join(
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
)
raise ValueError(
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
f"Supported transforms: none, {supported}."
)
return normalized
def infer_groot_model_version(model_path: str | None) -> str | None:
if not model_path:
return None
model_path_lower = model_path.lower()
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
return GROOT_N1_7
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
# N1.5 is unsupported, but it must still be recognised here to fail loudly
# rather than silently treating an N1.5 checkpoint as N1.7.
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
return GROOT_N1_5
config_version = _infer_groot_model_version_from_local_config(model_path)
if config_version is not None:
return config_version
return None
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
if model_path is None:
return False
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return False
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return False
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if "libero_sim" in modality_configs:
return "libero_sim"
if len(modality_configs) == 1:
return next(iter(modality_configs))
return None
def infer_groot_n1_7_action_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict):
return None
modality_configs = processor_kwargs.get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag is None:
return None
embodiment_config = modality_configs.get(embodiment_tag, {})
if not isinstance(embodiment_config, dict):
return None
action_config = embodiment_config.get("action", {})
if not isinstance(action_config, dict):
return None
delta_indices = action_config.get("delta_indices", [])
if not isinstance(delta_indices, list):
return None
return len(delta_indices) or None
def infer_groot_n1_7_action_execution_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
if action_horizon is None:
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag == "libero_sim":
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
# actions. Keeping that execution cadence avoids stale open-loop chunks.
return min(action_horizon, 8)
return action_horizon
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
model_path = Path(model_name).expanduser()
if model_path.exists():
return str(model_path)
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
return str(cached_snapshot) if cached_snapshot is not None else model_name
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
required_files = (
"config.json",
"tokenizer_config.json",
"preprocessor_config.json",
"video_preprocessor_config.json",
)
for hub_cache in _candidate_hf_hub_caches(cache_dir):
repo_cache = hub_cache / repo_cache_name
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.is_dir():
continue
candidates: list[Path] = []
ref_path = repo_cache / "refs" / "main"
try:
ref = ref_path.read_text().strip()
except OSError:
ref = ""
if ref:
candidates.append(snapshots_dir / ref)
candidates.extend(
sorted(
(path for path in snapshots_dir.iterdir() if path.is_dir()),
key=lambda path: path.stat().st_mtime,
reverse=True,
)
)
seen: set[Path] = set()
for snapshot in candidates:
if snapshot in seen:
continue
seen.add(snapshot)
if all((snapshot / filename).exists() for filename in required_files):
return snapshot
return None
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
candidates: list[Path] = []
if cache_dir is not None:
cache_path = Path(cache_dir).expanduser()
candidates.append(cache_path)
candidates.append(cache_path / "hub")
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
if hub_cache:
candidates.append(Path(hub_cache).expanduser())
hf_home = os.environ.get("HF_HOME")
if hf_home:
candidates.append(Path(hf_home).expanduser() / "hub")
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
deduped: list[Path] = []
seen: set[Path] = set()
for candidate in candidates:
resolved = candidate.resolve() if candidate.exists() else candidate
if resolved not in seen:
seen.add(resolved)
deduped.append(candidate)
return deduped
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return None
if not config_path.exists():
return None
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
return _infer_groot_model_version_from_config(config)
def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version")
if isinstance(model_version, str):
try:
return normalize_groot_model_version(model_version)
except ValueError:
return None
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
for candidate in candidates:
if not isinstance(candidate, str):
continue
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
return None
@PreTrainedConfig.register_subclass("groot")
@dataclass
@@ -334,17 +52,11 @@ class GrootConfig(PreTrainedConfig):
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
# Path or HuggingFace model ID for the base Groot model
base_model_path: str | None = None
base_model_path: str = "nvidia/GR00T-N1.5-3B"
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
action_decode_transform: str | None = None
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -405,38 +117,6 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
self.base_model_path = GROOT_N1_7_BASE_MODEL
# The N1.7 LIBERO checkpoints emit a [0, 1] gripper action, but the LIBERO
# simulator expects the OpenVLA/[-1, 1] sign convention. NVIDIA's rollout
# wrapper applies this conversion; mirror it here so eval on the
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
raise ValueError(
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
super().__post_init__()
if self.n_action_steps > self.chunk_size:
@@ -512,8 +192,7 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
return list(range(min(self.chunk_size, model_action_horizon)))
return list(range(min(self.chunk_size, 16)))
@property
def reward_delta_indices(self) -> None:
@@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Eagle25VLConfig(PretrainedConfig):
model_type = "eagle_2_5_vl"
is_composition = True
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
def __init__(
self,
vision_config=None,
text_config=None,
use_backbone_lora=0,
use_llm_lora=0,
pad2square=False,
select_layer=-4,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
loss_version="v1",
min_dynamic_tiles=1,
max_dynamic_tiles=6,
mlp_checkpoint=False,
initializer_range=0.02,
_attn_implementation="flash_attention_2",
_attn_implementation_autoset=False,
llm_config=None,
image_token_index=None,
use_pixel_shuffle=True,
mlp_connector_layers=2,
**kwargs,
):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {"model_type": "siglip_vision_model"}
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
if text_config is None:
text_config = {"architectures": ["Qwen2ForCausalLM"]}
logger.info(
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
)
if vision_config["model_type"] == "siglip_vision_model":
self.vision_config = SiglipVisionConfig(**vision_config)
else:
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
if text_config["architectures"][0] == "LlamaForCausalLM":
self.text_config = LlamaConfig(**text_config)
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
self.text_config = Qwen2Config(**text_config)
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
self.text_config = Qwen3Config(**text_config)
else:
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.mlp_checkpoint = mlp_checkpoint
self.pad2square = pad2square
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.loss_version = loss_version
self.initializer_range = initializer_range
self.min_dynamic_tiles = min_dynamic_tiles
self.max_dynamic_tiles = max_dynamic_tiles
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self._attn_implementation = _attn_implementation
self._attn_implementation_autoset = _attn_implementation_autoset
self.image_token_index = image_token_index
self.use_pixel_shuffle = use_pixel_shuffle
self.mlp_connector_layers = mlp_connector_layers
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
output["use_backbone_lora"] = self.use_backbone_lora
output["use_llm_lora"] = self.use_llm_lora
output["pad2square"] = self.pad2square
output["select_layer"] = self.select_layer
output["force_image_size"] = self.force_image_size
output["downsample_ratio"] = self.downsample_ratio
output["template"] = self.template
output["dynamic_image_size"] = self.dynamic_image_size
output["use_thumbnail"] = self.use_thumbnail
output["min_dynamic_tiles"] = self.min_dynamic_tiles
output["max_dynamic_tiles"] = self.max_dynamic_tiles
output["tie_word_embeddings"] = self.tie_word_embeddings
output["_attn_implementation"] = self._attn_implementation
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
output["use_pixel_shuffle"] = self.use_pixel_shuffle
output["mlp_connector_layers"] = self.mlp_connector_layers
return output
@@ -0,0 +1,503 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from __future__ import annotations
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
ImagesKwargs,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
make_flat_list_of_images,
validate_kwargs,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_v2_available,
)
from transformers.video_utils import VideoInput
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F # noqa: N812
from transformers.image_utils import pil_torch_interpolation_mapping
else:
from torchvision.transforms import functional as F # noqa: N812
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
"""Crop the given numpy array.
Args:
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
left (int): The left coordinate of the crop box.
top (int): The top coordinate of the crop box.
right (int): The right coordinate of the crop box.
bottom (int): The bottom coordinate of the crop box.
Returns:
torch.Tensor: Cropped image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
if img.ndim not in [2, 3]:
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
img_height = img.shape[1]
img_width = img.shape[2]
if top < 0 or left < 0 or bottom > img_height or right > img_width:
raise ValueError("Crop coordinates out of bounds")
if top >= bottom or left >= right:
raise ValueError("Invalid crop coordinates")
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
pad_during_tiling: bool | None
do_pad: bool | None
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method. Not used for processing videos.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 448, "width": 448}
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
max_dynamic_tiles = 12
min_dynamic_tiles = 1
use_thumbnail = True
pad_during_tiling = False
valid_kwargs = Eagle25VLFastImageProcessorKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
"""
max_dynamic_tiles (`int`, *optional*):
The maximum number of dynamic tiles to use for processing high resolution images.
min_dynamic_tiles (`int`, *optional*):
The minimum number of dynamic tiles to use for processing high resolution images.
use_thumbnail (`bool`, *optional*):
Whether to use a thumbnail for processing high resolution images.
pad_during_tiling (`bool`, *optional*):
Whether to pad the image during tiling.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
# NOTE(YL): we will overload the preprocess method to add the image_flags
# def preprocess(
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
# ) -> BatchFeature:
# return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
expected_ndims: int = 3,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
expected_ndims (`int`, *optional*, defaults to 3):
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _resize_for_patching(
self,
image: torch.Tensor,
target_resolution: tuple,
interpolation: F.InterpolationMode,
input_data_format: ChannelDimension,
) -> torch.Tensor:
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image ("torch.Tensor"):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
"torch.Tensor": The resized and padded image.
"""
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
return resized_image
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
"""
previous version mainly focus on ratio.
We also consider area ratio here.
"""
best_factor = float("-inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
"""
new area > 60% of original image area is enough.
"""
factor_based_on_area_n_ratio = min(
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
def _pad_for_patching(
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
) -> torch.Tensor:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
return padded_image
def _get_image_patches(
self,
image: torch.Tensor,
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: F.InterpolationMode,
pad_during_tiling: bool,
) -> list[torch.Tensor]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
# calculate the target width and height
target_width = tile_size * target_aspect_ratio[0]
target_height = tile_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if pad_during_tiling:
resized_image = self._resize_for_patching(
image,
(target_height, target_width),
interpolation=interpolation,
input_data_format=ChannelDimension.FIRST,
)
padded_image = self._pad_for_patching(
resized_image,
(target_height, target_width),
input_data_format=ChannelDimension.FIRST,
)
image_used_to_split = padded_image
else:
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
processed_tiles = []
for i in range(blocks):
box = (
(i % (target_width // tile_size)) * tile_size,
(i // (target_width // tile_size)) * tile_size,
((i % (target_width // tile_size)) + 1) * tile_size,
((i // (target_width // tile_size)) + 1) * tile_size,
)
# split the image
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
processed_tiles.append(split_img)
assert len(processed_tiles) == blocks
if use_thumbnail and len(processed_tiles) != 1:
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
processed_tiles.append(thumbnail_img)
return processed_tiles
def _pad_for_batching(
self,
pixel_values: list[torch.Tensor],
) -> list[torch.Tensor]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
for image in pixel_values
]
return pixel_values
def _preprocess(
self,
images: list[torch.Tensor],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: F.InterpolationMode | None,
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
do_pad: bool,
return_tensors: str | TensorType | None,
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
) -> BatchFeature:
processed_images = []
image_sizes = []
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
tile_size = crop_size.height
elif size and size.height:
tile_size = size.height
else:
tile_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
min_num=min_dynamic_tiles,
max_num=max_dynamic_tiles,
size=size_tuple,
tile_size=tile_size,
use_thumbnail=use_thumbnail,
interpolation=interpolation,
pad_during_tiling=pad_during_tiling,
)
# Group images by size for batched processing
processed_image_patches_grouped = {}
# Added for transformers >=4.53.0 compatibility
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches,
disable_grouping=disable_grouping,
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches,
do_rescale,
rescale_factor,
do_normalize,
image_mean,
image_std,
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(
processed_image_patches_grouped, grouped_image_patches_index
)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes},
tensor_type=return_tensors,
)
def preprocess(
self,
images: ImageInput,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
if images is not None:
images = self._prepare_image_like_inputs(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
if videos is not None:
videos = self._prepare_image_like_inputs(
images=videos,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
# Added for transformers >=4.53.0 compatibility
resample = kwargs.pop("resample", self.resample)
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, PILImageResampling | int)
else resample
)
# Filter kwargs to only include those accepted by _preprocess
valid_preprocess_kwargs = {
"do_resize",
"size",
"max_dynamic_tiles",
"min_dynamic_tiles",
"use_thumbnail",
"pad_during_tiling",
"interpolation",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"return_tensors",
"pad_size",
"disable_grouping",
}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
if images is not None:
return self._preprocess(images, **filtered_kwargs)
elif videos is not None:
return self._preprocess(videos, **filtered_kwargs)
__all__ = ["Eagle25VLImageProcessorFast"]
@@ -0,0 +1,396 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import inspect
import torch
import torch.utils.checkpoint as cp
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers.utils import add_start_docstrings, logging
from .configuration_eagle2_5_vl import Eagle25VLConfig
logger = logging.get_logger(__name__)
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
EAGLE2_5_VL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Eagle25VLConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
EAGLE2_5_VL_START_DOCSTRING,
)
class Eagle25VLPreTrainedModel(PreTrainedModel):
config_class = Eagle25VLConfig
base_model_prefix = "model"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_no_split_modules = [
"Qwen2DecoderLayer",
"LlamaDecoderLayer",
"Siglip2EncoderLayer",
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear | nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
config_class = Eagle25VLConfig
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
super().__init__(config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
if config.use_pixel_shuffle:
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
else:
self.num_image_token = int((image_size // patch_size) ** 2)
self.select_layer = config.select_layer
self.downsample_ratio = config.downsample_ratio
self.loss_version = config.loss_version
self.mlp_checkpoint = config.mlp_checkpoint
self.use_pixel_shuffle = config.use_pixel_shuffle
self.mlp_connector_layers = config.mlp_connector_layers
logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
if vision_model is not None:
self.vision_model = vision_model
else:
if config.vision_config.model_type == "siglip_vision_model":
config.vision_config._attn_implementation = "flash_attention_2"
self.vision_model = SiglipVisionModel(config.vision_config)
else:
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
if language_model is not None:
self.language_model = language_model
else:
if config.text_config.architectures[0] == "LlamaForCausalLM":
self.language_model = LlamaForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
raise NotImplementedError("Phi3 is not implemented.")
# self.language_model = Phi3ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
assert config.text_config._attn_implementation == "flash_attention_2", (
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
)
self.language_model = Qwen2ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
self.language_model = Qwen3ForCausalLM(config.text_config)
else:
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
if config.mlp_connector_layers == 2:
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size, llm_hidden_size),
)
else:
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
self.image_token_index = config.image_token_index
self.neftune_alpha = None
if config.use_backbone_lora:
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
self.use_llm_lora = config.use_llm_lora
if config.use_llm_lora:
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
self.check_forward_kwargs()
def check_forward_kwargs(self):
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
# has special handling for functions with **kwargs parameters that would affect
# how our model is processed during training and inference.
forward_params = inspect.signature(self.forward).parameters
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.out_proj",
"mlp.fc1",
"mlp.fc2",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
self.vision_model.print_trainable_parameters()
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
task_type="CAUSAL_LM",
)
self.language_model = get_peft_model(self.language_model, lora_config)
self.language_model.enable_input_require_grads()
self.language_model.print_trainable_parameters()
self.use_llm_lora = True
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
image_flags: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
num_tiles_list: list[torch.Tensor] | None = None,
) -> tuple | CausalLMOutputWithPast:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_embeds = self.language_model.get_input_embeddings()(input_ids)
vit_embeds = self.extract_feature(pixel_values)
if image_flags is not None:
image_flags = image_flags.view(-1)
vit_embeds = vit_embeds[image_flags == 1]
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.image_token_index
try:
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, c)
print(
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
f"vit_embeds.shape={vit_embeds.shape}"
)
n_token = selected.sum()
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
input_embeds = input_embeds.reshape(b, n, c)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
)
if hasattr(vit_embeds, "last_hidden_state"):
vit_embeds = vit_embeds.last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
).hidden_states[self.select_layer]
if self.use_pixel_shuffle:
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
if self.mlp_checkpoint and vit_embeds.requires_grad:
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
else:
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor | None = None,
input_ids: torch.FloatTensor | None = None,
attention_mask: torch.LongTensor | None = None,
visual_features: torch.FloatTensor | None = None,
generation_config: GenerationConfig | None = None,
output_hidden_states: bool | None = None,
image_sizes: list[tuple[int, int]] | None = None,
**generate_kwargs,
) -> torch.LongTensor:
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
input_embeds = self.language_model.get_input_embeddings()(input_ids)
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.config.image_token_index
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
input_embeds = input_embeds.reshape(b, n, c)
else:
input_embeds = self.language_model.get_input_embeddings()(input_ids)
if "use_cache" not in generate_kwargs:
generate_kwargs["use_cache"] = True
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
**generate_kwargs,
)
return outputs
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
@@ -0,0 +1,541 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Eagle25VL.
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
"""
import base64
import os
import re
from io import BytesIO
import requests
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
logger = logging.get_logger(__name__)
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 256
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == "RGBA":
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
image = ele["image"] if "image" in ele else ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True, timeout=10)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
)
image = to_rgb(image_obj)
if "scale_factor" in ele:
scale_factor = ele["scale_factor"]
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
return image
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
# see processing_utils.ProcessingKwargs documentation for usage.
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {},
"videos_kwargs": {"max_dynamic_tiles": 1},
}
class Eagle25VLProcessor(ProcessorMixin):
r"""
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
Args:
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
num_image_tokens (`int`, *optional*):
Number of image tokens for one imagethat will be returned by vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Should be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
video_token (`str`, *optional*, defaults to `"<video>"`):
Special token used to denote video location.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"num_image_tokens",
"vision_feature_select_strategy",
"image_token",
"video_token",
"images_kwargs",
"videos_kwargs",
"text_kwargs",
]
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
vision_feature_select_strategy=None,
chat_template=None,
image_token="<IMG_CONTEXT>", # nosec: B107
video_token="<IMG_CONTEXT>", # nosec: B107
tokens_per_tile=256,
image_placeholder="image",
video_placeholder="video",
image_start_token="<img>",
image_end_token="</img>",
**kwargs,
):
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None)
else tokenizer.convert_tokens_to_ids(self.image_token)
)
self.video_token_id = (
tokenizer.video_token_id
if getattr(tokenizer, "video_token_id", None)
else tokenizer.convert_tokens_to_ids(self.video_token)
)
self.image_placeholder = image_placeholder
self.video_placeholder = video_placeholder
self.tokens_per_tile = tokens_per_tile
self.image_start_token = image_start_token
self.image_end_token = image_end_token
if "auto_map" in kwargs:
self.auto_map = kwargs["auto_map"]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def replace_media_placeholder(
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
):
num_of_images_in_this_sample = 0
num_of_videos_in_this_sample = 0
# Regular expression pattern to match formats like <image-1> or <video-2>
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
unified_frame_list = []
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
# )
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
# )
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
# "use_thumbnail", self.image_processor.use_thumbnail
# )
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
)
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
)
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
"use_thumbnail", self.image_processor.use_thumbnail
)
tile_size = self.image_processor.size.get("height", 448)
# Function to replace tags in a single text
def replace_in_text(text):
# repl callback function for each match replacement operation
def repl(match):
nonlocal unified_frame_list
nonlocal num_of_images_in_this_sample
nonlocal num_of_videos_in_this_sample
media_type = match.group(1) # 'image' or 'video'
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
# Select the corresponding path based on media type
idx_mapper = {
0: "first",
1: "second",
2: "third",
3: "fourth",
4: "fifth",
5: "sixth",
6: "seventh",
7: "eighth",
8: "ninth",
9: "tenth",
}
if media_type == "image":
image_inputs = self.image_processor(
images=[image_list[idx_in_list]],
videos=None,
**output_kwargs["images_kwargs"],
)
if isinstance(image_inputs["pixel_values"], list):
_pv = image_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
image_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs)
num_of_images_in_this_sample += 1
elif media_type == "video":
video_inputs = self.image_processor(
images=None,
videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"],
)
if isinstance(video_inputs["pixel_values"], list):
_pv = video_inputs["pixel_values"]
if _pv and isinstance(_pv[0], list):
_pv = [t for sub in _pv for t in sub]
video_inputs["pixel_values"] = torch.stack(
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
)
num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list:
frame_timestamps = timestamps_list[idx_in_list]
else:
frame_timestamps = None
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
num_of_tiles_each_frame = [
self.get_number_tiles_based_on_image_size(
image_size,
video_min_dynamic_tiles,
video_max_dynamic_tiles,
video_use_thumbnail,
tile_size,
)
for image_size in image_sizes
]
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
)
if frame_timestamps is not None:
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
)
special_placeholder = [
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
else:
special_placeholder = [
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
if sampled_fps is not None:
special_placeholder = (
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
+ "".join(special_placeholder)
)
else:
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
special_placeholder
)
unified_frame_list.append(video_inputs)
num_of_videos_in_this_sample += 1
else:
raise ValueError(f"Unknown media type: {media_type}")
return special_placeholder
return pattern.sub(repl, text)
text = replace_in_text(text)
if len(unified_frame_list) > 0:
def _to_tensor(v):
if isinstance(v, torch.Tensor):
return v
if isinstance(v, list):
if v and isinstance(v[0], list):
v = [t for sub in v for t in sub]
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
return torch.as_tensor(v)
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
else:
pixel_values = None
image_sizes = None
return (
text,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
)
def __call__(
self,
images: ImageInput = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
audio=None,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Eagle25VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text_list = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
elif isinstance(text, list) and isinstance(text[0], str):
text_list = text
if images is None:
images = []
if videos is None:
videos = []
pixel_values_list = []
image_sizes_list = []
new_sample_list = []
image_start_idx = 0
video_start_idx = 0
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
for sample in text_list:
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
(
sample,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
) = self.replace_media_placeholder(
sample,
images[image_start_idx:],
videos[video_start_idx:],
timestamps_list,
fps_list,
**output_kwargs,
)
new_sample_list.append(sample)
if pixel_values is not None:
pixel_values_list.append(pixel_values)
image_sizes_list.append(image_sizes)
image_start_idx += num_of_images_in_this_sample
video_start_idx += num_of_videos_in_this_sample
if len(pixel_values_list) > 0:
image_inputs = {
"pixel_values": torch.cat(pixel_values_list),
"image_sizes": torch.cat(image_sizes_list),
}
else:
image_inputs = {}
video_inputs = {}
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
def get_number_tiles_based_on_image_size(
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
) -> int:
"""
Get the number of tiles based on the image size.
"""
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and tiles_num > 1:
tiles_num += 1
return tiles_num
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# override to save video-config in a separate config file
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
outputs = super().save_pretrained(save_directory, **kwargs)
return outputs
# override to load video-config from a separate config file
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
if isinstance(processor, tuple):
processor = processor[0]
return processor
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def process_vision_info(
self,
conversations: list[dict] | list[list[dict]],
return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
vision_infos = self.extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
video_timestamps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
if return_video_kwargs:
return (
image_inputs,
video_inputs,
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
)
return image_inputs, video_inputs
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele["type"] in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
__all__ = ["Eagle25VLProcessor"]
+380
View File
@@ -0,0 +1,380 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
from .action_head.flow_matching_action_head import (
FlowmatchingActionHead,
FlowmatchingActionHeadConfig,
)
from .utils import ensure_eagle_cache_ready
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
class EagleBackbone(nn.Module):
def __init__(
self,
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
project_to_dim: int = 1536,
):
"""
Args:
tune_llm: whether to tune the LLM model (default: True)
tune_visual: whether to tune the visual model (default: False)
"""
super().__init__()
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
try:
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
if project_to_dim is not None:
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
else:
self.eagle_linear = torch.nn.Identity()
# needed since we don't use these layers. Also saves compute
while len(self.eagle_model.language_model.model.layers) > select_layer:
self.eagle_model.language_model.model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual)
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for p in self.parameters():
p.requires_grad = True
if not tune_llm:
self.eagle_model.language_model.requires_grad_(False)
if not tune_visual:
self.eagle_model.vision_model.requires_grad_(False)
self.eagle_model.mlp1.requires_grad_(False)
print(f"Tune backbone llm: {self.tune_llm}")
print(f"Tune backbone visual: {self.tune_visual}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_llm and not tune_visual:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Backbone trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No backbone trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if self.eagle_model.language_model and not self.tune_llm:
self.eagle_model.language_model.eval()
if self.eagle_model.vision_model and not self.tune_visual:
self.eagle_model.vision_model.eval()
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
eagle_prefix = "eagle_"
eagle_input = {
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
}
del eagle_input["image_sizes"]
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
eagle_features = eagle_output.hidden_states[self.select_layer]
eagle_features = self.eagle_linear(eagle_features)
return eagle_features, eagle_input["attention_mask"]
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
if self.training and self.tune_visual:
dummy_term = torch.tensor(
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
)
for param in self.eagle_model.vision_model.parameters():
if param.requires_grad:
dummy_term = dummy_term + 0.0 * param.sum()
eagle_embeds = eagle_embeds + dummy_term
return BatchFeature(
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
) # [B, T2, hidden_size]
BACKBONE_FEATURE_KEY = "backbone_features"
ACTION_KEY = "action_pred"
LOSS_KEY = "loss"
ERROR_MSG = "Error: unexpected input/output"
N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
compute_dtype: str = "float32"
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
# real model
class GR00TN15(PreTrainedModel):
supports_gradient_checkpointing = True
config_class = GR00TN15Config
"""
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
here n is variable and can be e.g. time, 1 or user specified
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
"""
def __init__(
self,
config: GR00TN15Config,
local_model_path: str,
):
assert isinstance(config.backbone_cfg, dict)
assert isinstance(config.action_head_cfg, dict)
super().__init__(config)
self.local_model_path = local_model_path
self.backbone = EagleBackbone(**config.backbone_cfg)
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
self.action_head = FlowmatchingActionHead(action_head_cfg)
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
detected_error = False
error_msg = ERROR_MSG
if ACTION in inputs:
action = inputs[ACTION]
# In inference, action may be omitted or None; validate only when it's a tensor.
if action is None:
pass # allow None during inference
elif isinstance(action, torch.Tensor):
shape_ok = (
len(action.shape) == 3
and action.shape[1] == self.action_horizon
and action.shape[2] == self.action_dim
)
if not shape_ok:
error_msg += f"\n{action.shape=}"
detected_error = True
else:
# Unexpected non-tensor type provided for action
error_msg += f"\nInvalid type for action: {type(action)}"
detected_error = True
if "video" in inputs:
video = inputs["video"]
type_ok = isinstance(video, np.ndarray)
dtype_ok = video.dtype == np.uint8
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
if not type_ok:
error_msg += f"\n{type(video)=}"
detected_error = True
if not dtype_ok:
error_msg += f"\n{video.dtype=}"
detected_error = True
if not shape_ok:
error_msg += f"\n{video.shape=}"
detected_error = True
if detected_error:
raise ValueError(error_msg)
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
fail_backbone = (
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
)
if fail_backbone:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
raise ValueError(error_msg)
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
(
LOSS_KEY in action_head_outputs and is_training
) # there might not be an action prediction during training
or (
ACTION_KEY in action_head_outputs
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
)
)
if fail_action_head:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
error_msg += f"\n{self.action_horizon=}"
error_msg += f"\n{self.action_dim=}"
raise ValueError(error_msg)
def forward(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
return action_head_outputs
def get_action(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
return action_head_outputs
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
self.validate_inputs(inputs)
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_maybe_dtype(x):
# Cast floating tensors to a memory-efficient compute dtype when requested.
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
if getattr(self, "compute_dtype", None) == "bfloat16":
return x.to(self.device, dtype=torch.bfloat16)
# Fallback: preserve previous behavior if not using bf16 compute
return x.to(self.device, dtype=self.action_head.dtype)
# Non-floating tensors: move device only
return x.to(self.device)
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
return backbone_inputs, action_inputs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
print(f"Tune backbone vision tower: {tune_visual}")
print(f"Tune backbone LLM: {tune_llm}")
print(f"Tune action head projector: {tune_projector}")
print(f"Tune action head DiT: {tune_diffusion_model}")
# get the current model path being downloaded
try:
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
# saved in ~/.cache/huggingface/hub/
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
# HFValidationError, RepositoryNotFoundError
except (HFValidationError, RepositoryNotFoundError):
print(
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
)
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path, local_model_path=local_model_path, **kwargs
)
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
)
return pretrained_model
-962
View File
@@ -1,962 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib
import json
import logging
from contextlib import suppress
from copy import deepcopy
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
try:
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
except ImportError:
Qwen3VLConfig = None
Qwen3VLForConditionalGeneration = None
logger = logging.getLogger(__name__)
def _copy_default(value: Any) -> Any:
return deepcopy(value)
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"model_dtype": "bfloat16",
"dtype": "bfloat16",
"model_name": "nvidia/Cosmos-Reason2-2B",
"backbone_model_type": "qwen",
"model_revision": None,
"tune_top_llm_layers": 0,
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 12,
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": (230, 230),
"image_target_size": (256, 256),
"shortest_image_edge": None,
"crop_fraction": None,
"random_rotation_angle": None,
"color_jitter_params": None,
"use_albumentations_transforms": True,
"extra_augmentation_config": None,
"formalize_language": True,
"apply_sincos_state_encoding": False,
"use_percentiles": True,
"use_relative_action": False,
"max_state_dim": 132,
"max_action_dim": 132,
"action_horizon": 40,
"hidden_size": 1024,
"input_embedding_dim": 1536,
"state_history_length": 1,
"add_pos_embed": True,
"attn_dropout": 0.2,
"use_vlln": True,
"max_seq_len": 1024,
"use_alternate_vl_dit": True,
"attend_text_every_n_blocks": 2,
"diffusion_model_cfg": {
"positional_embeddings": None,
"num_layers": 32,
"num_attention_heads": 32,
"attention_head_dim": 48,
"norm_type": "ada_norm",
"dropout": 0.2,
"final_dropout": True,
"output_dim": 1024,
"interleave_self_attention": True,
},
"vl_self_attention_cfg": {
"positional_embeddings": None,
"num_layers": 4,
"num_attention_heads": 32,
"attention_head_dim": 64,
"dropout": 0.2,
"final_dropout": True,
},
"num_inference_timesteps": 4,
"noise_beta_alpha": 1.5,
"noise_beta_beta": 1.0,
"noise_s": 0.999,
"num_timestep_buckets": 1000,
"tune_projector": True,
"tune_diffusion_model": True,
"tune_vlln": True,
"state_dropout_prob": 0.2,
"exclude_state": False,
"use_mean_std": False,
"max_num_embodiments": 32,
"rtc_ramp_rate": 6.0,
}
class GR00TN17Config(PretrainedConfig):
"""Configuration for NVIDIA GR00T N1.7.
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
flow-matching action head. This mirrors the public N1.7 checkpoint config
while keeping it local to LeRobot and independent from the external
Isaac-GR00T ``gr00t`` Python package.
"""
model_type = "Gr00tN1d7"
_defaults = GR00T_N1_7_DEFAULTS
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in GR00T_N1_7_DEFAULTS.items():
setattr(self, key, _copy_default(kwargs.pop(key, value)))
for key, value in kwargs.items():
setattr(self, key, value)
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
cfg = self.to_dict()
if not exclude_augment:
return cfg
exclude_keys = {
"random_rotation_angle",
"color_jitter_params",
"use_albumentations_transforms",
"formalize_language",
"image_crop_size",
"image_target_size",
"shortest_image_edge",
"crop_fraction",
}
return {k: v for k, v in cfg.items() if k not in exclude_keys}
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
class CategorySpecificLinear(nn.Module):
"""Linear layer with category-specific weights for multi-embodiment support."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
super().__init__()
self.num_categories = num_categories
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
"""Two-layer MLP with category-specific weights."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class SinusoidalPositionalEncoding(nn.Module):
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
The frequency scalar is intentionally created on CPU and then broadcast with
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
embedding and avoids tiny dtype/device construction differences in parity
tests.
"""
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class MultiEmbodimentActionEncoder(nn.Module):
"""Action encoder with category-specific projections and sinusoidal time encoding."""
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
super().__init__()
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
batch_size, horizon, _ = actions.shape
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
raise ValueError("Expected `timesteps` to have shape (B,).")
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
action_emb = self.W1(actions, cat_ids)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
return self.W3(x, cat_ids)
class Qwen3Backbone(nn.Module):
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
The public checkpoint stores the action head in the GR00T checkpoint but
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
keeps the nested HF module layout compatible across transformer versions
and exposes the hidden states consumed by the action head.
"""
def __init__(
self,
model_name: str = "nvidia/Cosmos-Reason2-2B",
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
tune_top_llm_layers: int = 0,
trainable_params_fp32: bool = False,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_pretrained_weights: bool = True,
):
if Qwen3VLForConditionalGeneration is None:
raise ImportError(
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
"or use a transformers version that provides Qwen3-VL support."
)
super().__init__()
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
extra_kwargs: dict[str, Any] = {}
if use_flash_attention:
try:
import flash_attn # noqa: F401
extra_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
extra_kwargs["attn_implementation"] = "sdpa"
if load_bf16:
extra_kwargs["torch_dtype"] = torch.bfloat16
if load_pretrained_weights:
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
model_name,
**extra_kwargs,
**transformers_loading_kwargs,
).eval()
else:
self.model = self._from_backbone_config(
model_name=model_name,
model_kwargs=extra_kwargs,
config_kwargs=transformers_loading_kwargs,
).eval()
while len(self.language_model.layers) > select_layer:
self.language_model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
if load_bf16 and trainable_params_fp32:
for parameter in self.parameters():
if parameter.requires_grad:
parameter.data = parameter.data.to(torch.float32)
def set_trainable_parameters(
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
) -> None:
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_llm:
self.language_model.requires_grad_(False)
if not tune_visual:
self.visual.requires_grad_(False)
if tune_top_llm_layers > 0:
for layer in self.language_model.layers[-tune_top_llm_layers:]:
for parameter in layer.parameters():
parameter.requires_grad = True
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if self.language_model and not self.tune_llm:
self.language_model.eval()
if self.visual and not self.tune_visual:
self.visual.eval()
@property
def language_model(self) -> nn.Module:
return getattr(self.model, "model", self.model).language_model
@property
def visual(self) -> nn.Module:
return getattr(self.model, "model", self.model).visual
def _from_backbone_config(
self,
*,
model_name: str,
model_kwargs: dict[str, Any],
config_kwargs: dict[str, Any],
) -> nn.Module:
if _is_cosmos_reason2_backbone(model_name):
backbone_config = _cosmos_reason2_qwen3_vl_config()
else:
if AutoConfig is None:
raise ImportError(
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
if "mm_token_type_ids" in model_input:
return
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
return
input_ids = model_input.get("input_ids")
if input_ids is None:
return
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None:
mm_token_type_ids[input_ids == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[input_ids == video_token_id] = 2
model_input["mm_token_type_ids"] = mm_token_type_ids
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
drops text position ids before calling text-layer flash attention. GR00T
N1.7 was aligned against the older Transformers path, where a fourth text
position row is forwarded alongside the temporal/height/width rows. Adding
the row here preserves the newer multimodal position computation while
keeping flash attention on the legacy code path.
"""
if "position_ids" in model_input:
return
qwen3_model = getattr(self.model, "model", self.model)
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
if compute_3d_position_ids is None:
return
position_ids = compute_3d_position_ids(
input_ids=model_input.get("input_ids"),
image_grid_thw=model_input.get("image_grid_thw"),
video_grid_thw=model_input.get("video_grid_thw"),
inputs_embeds=None,
attention_mask=model_input.get("attention_mask"),
past_key_values=None,
mm_token_type_ids=model_input.get("mm_token_type_ids"),
)
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
model_input["position_ids"] = position_ids
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
Newer releases expose the post-final-norm tensor there instead. Capturing
the last decoder layer output directly keeps the N1.7 action head input
stable across Transformers versions.
"""
captured: dict[str, torch.Tensor] = {}
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
if isinstance(output, torch.Tensor):
captured["features"] = output
elif isinstance(output, (tuple, list)) and output:
captured["features"] = output[0]
elif hasattr(output, "last_hidden_state"):
captured["features"] = output.last_hidden_state
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
try:
outputs = self.model(**model_input, output_hidden_states=True)
finally:
hook.remove()
return captured.get("features", outputs.hidden_states[-1])
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
model_input = {key: vl_input[key] for key in keys_to_use}
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
self._ensure_mm_token_type_ids(model_input)
self._ensure_legacy_qwen3_position_ids(model_input)
features = self._last_decoder_layer_output(model_input)
image_mask = model_input["input_ids"] == self.model.config.image_token_id
attention_mask = model_input["attention_mask"] == 1
return BatchFeature(
data={
"backbone_features": features,
"backbone_attention_mask": attention_mask,
"image_mask": image_mask,
}
)
class GR00TN17ActionHead(nn.Module):
supports_gradient_checkpointing = True
def __init__(self, config: GR00TN17Config):
require_package("diffusers", extra="groot")
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
if config.use_alternate_vl_dit:
self.model = AlternateVLDiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
)
else:
self.model = DiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
)
self.action_dim = config.max_action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim * config.state_history_length,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
else:
self.vl_self_attention = nn.Identity()
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.state_dropout_prob = config.state_dropout_prob
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
def set_trainable_parameters(
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
) -> None:
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
self.tune_vlln = tune_vlln
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
if not tune_vlln:
self.vlln.requires_grad_(False)
self.vl_self_attention.requires_grad_(False)
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
if not self.tune_vlln:
self.vlln.eval()
self.vl_self_attention.eval()
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
if self._beta_dist is None:
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (1 - sample) * self.config.noise_s
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = self.vlln(backbone_output["backbone_features"])
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
vl_embeds = backbone_output.backbone_features
device = vl_embeds.device
embodiment_id = action_input.embodiment_id
if action_input.state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = action_input.state.view(action_input.state.shape[0], 1, -1)
state_features = self.state_encoder(state, embodiment_id)
if self.training and self.state_dropout_prob > 0:
do_dropout = (
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
)
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None]
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
return BatchFeature(
data={
"loss": loss,
"action_loss": action_loss,
"action_mask": action_mask,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
state = action_input.state
if state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = state.view(state.shape[0], 1, -1)
state_features = self.state_encoder(state, action_input.embodiment_id)
return BatchFeature(
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
)
@torch.no_grad()
def get_action_with_features(
self,
backbone_features: torch.Tensor,
state_features: torch.Tensor,
embodiment_id: torch.Tensor,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
vl_embeds = backbone_features
batch_size = vl_embeds.shape[0]
device = vl_embeds.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.action_dim),
dtype=vl_embeds.dtype,
device=device,
)
dt = 1.0 / self.num_inference_timesteps
vel_strength = torch.ones_like(actions)
if "action" in action_input:
if options is None:
raise ValueError("RTC options are required when action is provided to get_action.")
action_horizon_before_padding = options["action_horizon"]
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
:,
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
:,
]
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
ramp = ramp / ramp[-1].clamp_min(1e-8)
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
None, :, None
].to(device)
for t_step in range(self.num_inference_timesteps):
t_cont = t_step / float(self.num_inference_timesteps)
t_discretized = int(t_cont * self.num_timestep_buckets)
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
return BatchFeature(
data={
"action_pred": actions,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
@torch.no_grad()
def get_action(
self,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
features = self._encode_features(backbone_output, action_input)
return self.get_action_with_features(
backbone_features=features.backbone_features,
state_features=features.state_features,
embodiment_id=action_input.embodiment_id,
backbone_output=backbone_output,
action_input=action_input,
options=options,
)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
if Qwen3VLConfig is None:
raise ImportError(
"Qwen3VLConfig is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
return Qwen3VLConfig(
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=True,
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"tie_word_embeddings": True,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [5, 11, 17],
"depth": 24,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1024,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 2048,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
)
def get_backbone_cls(config: GR00TN17Config):
if (
config.backbone_model_type == "qwen"
or "nvidia/Cosmos-Reason2" in config.model_name
or "Qwen/Qwen3-VL" in config.model_name
):
return Qwen3Backbone
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
class GR00TN17(PreTrainedModel):
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
config_class = GR00TN17Config
supports_gradient_checkpointing = True
def __init__(
self,
config: GR00TN17Config,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_backbone_weights: bool = True,
):
super().__init__(config)
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
self.config = config
backbone_cls = get_backbone_cls(config)
self.backbone = backbone_cls(
model_name=config.model_name,
tune_llm=config.tune_llm,
tune_visual=config.tune_visual,
select_layer=config.select_layer,
reproject_vision=config.reproject_vision,
use_flash_attention=config.use_flash_attention,
load_bf16=config.load_bf16,
tune_top_llm_layers=config.tune_top_llm_layers,
trainable_params_fp32=config.backbone_trainable_params_fp32,
transformers_loading_kwargs=transformers_loading_kwargs,
load_pretrained_weights=load_backbone_weights,
)
self.action_head = GR00TN17ActionHead(config)
self.post_init()
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
global tree
if tree is None:
require_package("dm-tree", extra="groot", import_name="tree")
tree = importlib.import_module("tree")
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_dtype(x):
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
return x.to(self.device, dtype=self.dtype)
return x.to(self.device)
return (
tree.map_structure(to_device_with_dtype, backbone_inputs),
tree.map_structure(to_device_with_dtype, action_inputs),
)
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head(backbone_outputs, action_inputs)
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head.get_action(backbone_outputs, action_inputs, options)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
tune_vlln = kwargs.pop("tune_vlln", True)
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
for key in ("revision", "cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
try:
local_model_path = snapshot_download(
pretrained_model_name_or_path,
repo_type="model",
revision=kwargs.get("revision"),
cache_dir=kwargs.get("cache_dir"),
local_files_only=kwargs.get("local_files_only", False),
token=kwargs.get("token"),
)
except (HFValidationError, RepositoryNotFoundError):
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path,
transformers_loading_kwargs=transformers_loading_kwargs,
load_backbone_weights=load_backbone_weights,
**kwargs,
)
pretrained_model.backbone.set_trainable_parameters(
tune_visual=tune_visual,
tune_llm=tune_llm,
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector,
tune_diffusion_model=tune_diffusion_model,
tune_vlln=tune_vlln,
)
return pretrained_model
def _register_with_transformers() -> None:
if AutoConfig is None or AutoModel is None:
return
try:
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
try:
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoModel.register(GR00TN17Config, GR00TN17)
_register_with_transformers()
+49 -221
View File
@@ -17,8 +17,14 @@
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code.
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
@@ -40,15 +46,8 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_groot import (
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
normalize_groot_model_version,
)
from .configuration_groot import GrootConfig
from .groot_n1 import GR00TN15
T = TypeVar("T", bound="GrootPolicy")
@@ -68,28 +67,27 @@ class GrootPolicy(PreTrainedPolicy):
# Initialize GR00T model using ported components
self._groot_model = self._create_groot_model()
self._action_queue_steps = self._resolve_action_queue_steps()
self.reset()
def _create_groot_model(self):
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
"""Create and initialize the GR00T model using Isaac-GR00T API.
This is only called when creating a NEW policy (not when loading from checkpoint).
Steps (delegating to Isaac-GR00T):
1) Download and load pretrained model via GR00TN15.from_pretrained
2) Align action horizon with data_config if provided
"""
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
model_kwargs = {
"pretrained_model_name_or_path": self.config.base_model_path,
"tune_llm": self.config.tune_llm,
"tune_visual": self.config.tune_visual,
"tune_projector": self.config.tune_projector,
"tune_diffusion_model": self.config.tune_diffusion_model,
}
from .groot_n1_7 import GR00TN17
model = GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=True,
transformers_loading_kwargs={"trust_remote_code": True},
model = GR00TN15.from_pretrained(
pretrained_model_name_or_path=self.config.base_model_path,
tune_llm=self.config.tune_llm,
tune_visual=self.config.tune_visual,
tune_projector=self.config.tune_projector,
tune_diffusion_model=self.config.tune_diffusion_model,
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
@@ -99,7 +97,7 @@ class GrootPolicy(PreTrainedPolicy):
def reset(self):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self._action_queue_steps)
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(
@@ -120,7 +118,7 @@ class GrootPolicy(PreTrainedPolicy):
"""Load Groot policy from pretrained model.
Handles two cases:
1. Base GR00T N1.7 models - loads the raw model
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
Args:
@@ -143,13 +141,8 @@ class GrootPolicy(PreTrainedPolicy):
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
requested_version = (
normalize_groot_model_version(config.model_version)
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
print(
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
)
@@ -200,12 +193,8 @@ class GrootPolicy(PreTrainedPolicy):
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
# Create default config with the pretrained path
config = GrootConfig(
model_version=model_version,
base_model_path=str(pretrained_name_or_path),
)
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
# Add minimal visual feature required for validation
# validate_features() will automatically add state and action features
@@ -226,13 +215,6 @@ class GrootPolicy(PreTrainedPolicy):
if hasattr(config, key):
setattr(config, key, value)
config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version:
raise ValueError(
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -243,164 +225,21 @@ class GrootPolicy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _resolve_action_queue_steps(self) -> int:
n_action_steps = int(self.config.n_action_steps)
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
execution_horizon = infer_groot_n1_7_action_execution_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
horizons = [n_action_steps]
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
return min(horizons)
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
horizons = [actions.shape[1]]
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
for horizon in (self.config.chunk_size, self.config.n_action_steps):
horizon = int(horizon)
if horizon > 0:
horizons.append(horizon)
return max(1, min(horizons))
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
allowed_base = {"state", "state_mask", "embodiment_id"}
if include_action:
allowed_base.update({"action", "action_mask"})
allowed_base.update(
{
"input_ids",
"attention_mask",
"pixel_values",
"image_grid_thw",
"mm_token_type_ids",
"pixel_values_videos",
"video_grid_thw",
}
)
allowed_base.add("action_mask")
return {
k: v
for k, v in batch.items()
if k in allowed_base and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
self,
inputs: dict[str, Tensor],
*,
inference_delay: object,
prev_chunk_left_over: object,
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
if prev_chunk_left_over is None:
return inputs, None
if not isinstance(prev_chunk_left_over, torch.Tensor):
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
if prev_chunk_left_over.numel() == 0:
return inputs, None
prev_actions = prev_chunk_left_over
if prev_actions.ndim == 2:
prev_actions = prev_actions.unsqueeze(0)
elif prev_actions.ndim != 3:
raise ValueError(
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
)
state = inputs.get("state")
if state is None:
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
batch_size = state.shape[0]
if prev_actions.shape[0] == 1 and batch_size > 1:
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
elif prev_actions.shape[0] != batch_size:
raise ValueError(
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
)
# The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
# provided prefix row as a real action constraint, so strip that padding
# before constructing the native overlap options.
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
if valid_prefix_rows.any():
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
prev_actions = prev_actions[:, :valid_prefix_steps, :]
else:
return inputs, None
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon:
prev_actions = prev_actions[:, -model_action_horizon:, :]
action_horizon = int(prev_actions.shape[1])
if action_horizon <= 0:
return inputs, None
if prev_actions.shape[2] > max_action_dim:
prev_actions = prev_actions[:, :, :max_action_dim]
elif prev_actions.shape[2] < max_action_dim:
pad = torch.zeros(
prev_actions.shape[0],
prev_actions.shape[1],
max_action_dim - prev_actions.shape[2],
dtype=prev_actions.dtype,
device=prev_actions.device,
)
prev_actions = torch.cat([prev_actions, pad], dim=2)
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
rtc_config = getattr(self.config, "rtc_config", None)
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
overlap_steps = max(0, min(action_horizon, execution_horizon))
if overlap_steps == 0:
return inputs, None
try:
frozen_steps = int(inference_delay or 0)
except (TypeError, ValueError):
frozen_steps = 0
frozen_steps = max(0, min(frozen_steps, overlap_steps))
options = {
"action_horizon": action_horizon,
"rtc_overlap_steps": overlap_steps,
"rtc_frozen_steps": frozen_steps,
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
}
inputs = dict(inputs)
inputs["action"] = prev_actions
return inputs, options
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Training forward pass.
Delegates to Isaac-GR00T model.forward when inputs are compatible.
"""
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = get_device_from_parameters(self)
device = next(self.parameters()).device
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
@@ -415,43 +254,32 @@ class GrootPolicy(PreTrainedPolicy):
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
Returns a tensor of shape (B, n_action_steps, action_dim).
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
action-overlap options before calling the underlying model.
"""
self.eval()
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
groot_options = None
if self.config.model_version == GROOT_N1_7:
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
groot_inputs,
inference_delay=kwargs.get("inference_delay"),
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
)
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
# Preprocessing is handled by the processor pipeline, so we just filter the batch
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
allowed_base = {"state", "state_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = get_device_from_parameters(self)
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
if groot_options is not None:
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
else:
outputs = self._groot_model.get_action(groot_inputs)
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
prediction_horizon = self._resolve_prediction_horizon(actions)
actions = actions[:, :prediction_horizon]
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
@@ -464,7 +292,7 @@ class GrootPolicy(PreTrainedPolicy):
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
# -------------------------
File diff suppressed because it is too large Load Diff
+47
View File
@@ -0,0 +1,47 @@
from pathlib import Path
from shutil import copytree
from huggingface_hub import hf_hub_download
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
"""
cache_dir = Path(cache_dir)
vendor_dir = Path(vendor_dir)
try:
# Populate/refresh cache with vendor files to ensure a complete processor directory
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
required_assets = [
"vocab.json",
"merges.txt",
"added_tokens.json",
"chat_template.json",
"special_tokens_map.json",
"config.json",
"generation_config.json",
"preprocessor_config.json",
"processor_config.json",
"tokenizer_config.json",
]
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
for fname in required_assets:
dst = cache_dir / fname
if not dst.exists():
print(f"[GROOT] Fetching {fname}")
hf_hub_download(
repo_id=assets_repo,
filename=fname,
repo_type="model",
local_dir=str(cache_dir),
)
+1 -1
View File
@@ -1 +1 @@
../../../../docs/source/policy_molmoact2_README.md
../../../../docs/source/molmoact2.mdx
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,16 +14,9 @@
from __future__ import annotations
import json
import math
import os
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import (
AdamWConfig,
@@ -37,146 +28,6 @@ from lerobot.utils.constants import ACTION, OBS_STATE
from ..rtc.configuration_rtc import RTCConfig
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
@LRSchedulerConfig.register_subclass("molmoact2_cosine_decay_with_warmup")
@dataclass
class MolmoAct2CosineDecayWithWarmupSchedulerConfig(CosineDecayWithWarmupSchedulerConfig):
"""MolmoAct2-local cosine scheduler with optional decay-step auto-match.
LeRobot's generic cosine scheduler keeps an explicit integer decay length.
For MolmoAct2, leaving num_decay_steps unset means "decay across this run's
training steps"; build() is the first point where num_training_steps is known.
"""
num_decay_steps: int | None
def build(self, optimizer, num_training_steps: int):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.peak_lr,
decay_lr=self.decay_lr,
num_warmup_steps=self.num_warmup_steps,
num_decay_steps=num_training_steps if self.num_decay_steps is None else self.num_decay_steps,
).build(optimizer, num_training_steps=num_training_steps)
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
@PreTrainedConfig.register_subclass("molmoact2")
@dataclass
@@ -255,7 +106,7 @@ class MolmoAct2Config(PreTrainedConfig):
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 200
scheduler_decay_steps: int | None = None
scheduler_decay_steps: int = 100_000
scheduler_decay_lr: float = 1e-6
normalization_mapping: dict[str, NormalizationMode] = field(
@@ -333,41 +184,6 @@ class MolmoAct2Config(PreTrainedConfig):
if self.max_sequence_length is not None and self.max_sequence_length < 1:
raise ValueError(f"max_sequence_length must be >= 1 or None, got {self.max_sequence_length}.")
def inferred_max_sequence_length(
self,
*,
num_images: int | None = None,
state_dim: int | None = None,
action_dim: int | None = None,
action_horizon: int | None = None,
include_discrete_action: bool | None = None,
) -> int:
if self.max_sequence_length is not None:
return int(self.max_sequence_length)
if num_images is None:
num_images = len(self.image_keys) or len(self.image_features) or MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim is None:
state_feature = self.robot_state_feature
state_dim = int(state_feature.shape[0]) if state_feature is not None else 0
if action_dim is None:
action_feature = self.action_feature
action_dim = (
int(action_feature.shape[0]) if action_feature is not None else self.expected_max_action_dim
)
if action_horizon is None:
action_horizon = self.chunk_size
if include_discrete_action is None:
include_discrete_action = self.action_mode in {"discrete", "both"}
return infer_molmoact2_max_sequence_length(
num_images=int(num_images),
state_dim=int(state_dim),
action_dim=int(action_dim),
action_horizon=int(action_horizon),
include_discrete_action=bool(include_discrete_action),
)
@property
def observation_delta_indices(self) -> None:
return None
@@ -390,7 +206,7 @@ class MolmoAct2Config(PreTrainedConfig):
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return MolmoAct2CosineDecayWithWarmupSchedulerConfig(
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
@@ -426,94 +242,3 @@ class MolmoAct2Config(PreTrainedConfig):
shape=(self.expected_max_action_dim,),
)
self.output_features[ACTION] = action_feature
def apply_norm_tag_metadata(self) -> None:
if not str(self.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
self.checkpoint_path,
revision=self.checkpoint_revision,
force_download=bool(self.checkpoint_force_download),
norm_tag=self.norm_tag,
)
if metadata.get("action_horizon") is not None:
self.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
self.n_action_steps = int(metadata["n_action_steps"])
if not self.setup_type and metadata.get("setup_type") is not None:
self.setup_type = str(metadata["setup_type"])
if not self.control_mode and metadata.get("control_mode") is not None:
self.control_mode = str(metadata["control_mode"])
def saved_policy_action_mode(self) -> str | None:
pretrained_path = getattr(self, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def training_action_mode(self, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or self.action_mode
def validate_inference_action_mode(self, saved_policy_action_mode: str | None = None) -> None:
requested_mode = self.inference_action_mode
if requested_mode is None:
return
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def validate_checkpoint_action_mode(
self,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
if self.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if self.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if self.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def resolve_inference_action_mode(
self,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
training_mode = self.training_action_mode(saved_policy_action_mode)
if requested_mode is None:
requested_mode = self.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,9 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""MolmoAct2 policy for LeRobot.
MolmoAct2 is a VLM-based robotics policy from Allen AI that combines a
Molmo vision-language backbone with a per-layer flow-matching action expert
for continuous action generation, plus an optional discrete action token
head. This module wraps the vendored HF model implementation
(``molmoact2_hf_model/``) into the LeRobot ``PreTrainedPolicy`` interface.
Paper: https://allenai.org/blog/molmoact2
Code: https://github.com/allenai/molmoact2
"""
from __future__ import annotations
import json
import logging
import os
import types
from collections import deque
@@ -35,13 +46,58 @@ from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from ..rtc.modeling_rtc import RTCProcessor
from .configuration_molmoact2 import MolmoAct2Config, _hf_token, _resolve_checkpoint_location
from .configuration_molmoact2 import MolmoAct2Config
logger = logging.getLogger(__name__)
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
"""Resolve a checkpoint path to a local directory, downloading from Hub if needed."""
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
from pathlib import Path
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
from huggingface_hub import snapshot_download
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _torch_dtype(dtype: str) -> torch.dtype:
"""Convert a dtype name string to a torch.dtype."""
if dtype == "float32":
return torch.float32
if dtype == "bfloat16":
return torch.bfloat16
if dtype == "float16":
return torch.float16
raise ValueError(f"Unsupported dtype: {dtype}")
if TYPE_CHECKING or _transformers_available:
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from .hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config
from .hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
from .molmoact2_hf_model.configuration_molmoact2 import MolmoAct2Config as HFMolmoAct2Config
from .molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
else:
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
@@ -49,7 +105,7 @@ else:
MolmoAct2ForConditionalGeneration = None
if TYPE_CHECKING or (_transformers_available and _scipy_available):
from .hf_model.action_tokenizer import UniversalActionProcessor
from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor
else:
UniversalActionProcessor = None
@@ -70,6 +126,156 @@ _MODEL_INPUT_KEYS = {
}
def _load_hf_norm_metadata_for_tag(
checkpoint_path: str,
*,
revision: str | None,
force_download: bool,
norm_tag: str | None,
) -> dict[str, Any]:
"""Read per-tag metadata from the checkpoint's ``norm_stats.json``."""
norm_tag = str(norm_tag or "").strip()
if not norm_tag:
return {}
from contextlib import suppress
from pathlib import Path
checkpoint_location = Path(
_resolve_checkpoint_location(
checkpoint_path,
revision=revision,
force_download=force_download,
)
)
norm_stats_filename = "norm_stats.json"
config_path = checkpoint_location / "config.json"
if config_path.exists():
with suppress(OSError, json.JSONDecodeError):
norm_stats_filename = str(
json.loads(config_path.read_text()).get("norm_stats_filename") or norm_stats_filename
)
stats_path = checkpoint_location / norm_stats_filename
if not stats_path.exists():
raise FileNotFoundError(
f"MolmoAct2 HF checkpoint is missing {norm_stats_filename!r}; cannot resolve norm_tag={norm_tag!r}."
)
payload = json.loads(stats_path.read_text())
metadata_by_tag = payload.get("metadata_by_tag")
if not isinstance(metadata_by_tag, dict):
raise ValueError(f"MolmoAct2 norm stats file {stats_path} has no metadata_by_tag mapping.")
metadata = metadata_by_tag.get(norm_tag)
if not isinstance(metadata, dict):
available = sorted(str(tag) for tag in metadata_by_tag)
raise ValueError(f"Unknown MolmoAct2 norm_tag={norm_tag!r}. Available tags: {available}.")
return metadata
def _apply_norm_tag_metadata(config: MolmoAct2Config) -> None:
"""Populate config fields from the checkpoint's norm-tag metadata."""
if not str(config.norm_tag or "").strip():
return
metadata = _load_hf_norm_metadata_for_tag(
config.checkpoint_path,
revision=config.checkpoint_revision,
force_download=bool(config.checkpoint_force_download),
norm_tag=config.norm_tag,
)
if metadata.get("action_horizon") is not None:
config.chunk_size = int(metadata["action_horizon"])
if metadata.get("n_action_steps") is not None:
config.n_action_steps = int(metadata["n_action_steps"])
if not config.setup_type and metadata.get("setup_type") is not None:
config.setup_type = str(metadata["setup_type"])
if not config.control_mode and metadata.get("control_mode") is not None:
config.control_mode = str(metadata["control_mode"])
def _saved_policy_action_mode(config: MolmoAct2Config) -> str | None:
"""Read the action mode from a LeRobot-saved checkpoint's ``config.json``."""
from pathlib import Path
pretrained_path = getattr(config, "pretrained_path", None)
if pretrained_path is None:
return None
config_path = Path(pretrained_path) / "config.json"
if not config_path.exists():
return None
try:
mode = json.loads(config_path.read_text()).get("action_mode")
except (OSError, json.JSONDecodeError):
return None
if mode in {"continuous", "discrete", "both"}:
return str(mode)
return None
def _training_action_mode(config: MolmoAct2Config, saved_policy_action_mode: str | None = None) -> str:
return saved_policy_action_mode or config.action_mode
def _validate_inference_action_mode(
config: MolmoAct2Config, saved_policy_action_mode: str | None = None
) -> None:
"""Check that the requested inference mode is compatible with the training mode."""
requested_mode = config.inference_action_mode
if requested_mode is None:
return
training_mode = _training_action_mode(config, saved_policy_action_mode)
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='discrete' and cannot run "
"continuous inference."
)
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError(
"MolmoAct2 checkpoint was trained with action_mode='continuous' and cannot run "
"discrete inference. Train with action_mode='both' or action_mode='discrete' first."
)
def _validate_checkpoint_action_mode(
config: MolmoAct2Config,
checkpoint_action_mode: str,
*,
has_action_expert: bool,
) -> None:
"""Check that the checkpoint's action mode is compatible with the config."""
if config.action_mode == "both" and checkpoint_action_mode != "both":
raise ValueError(
f"action_mode='both' requires checkpoint action_mode='both', got {checkpoint_action_mode!r}."
)
if config.action_mode == "discrete" and checkpoint_action_mode not in {"discrete", "both"}:
raise ValueError(
f"action_mode='discrete' requires checkpoint action_mode in {{'discrete', 'both'}}, "
f"got {checkpoint_action_mode!r}."
)
if config.action_mode in {"continuous", "both"} and not has_action_expert:
raise ValueError("Continuous MolmoAct2 training requires an action expert checkpoint.")
def _resolve_inference_action_mode(
config: MolmoAct2Config,
requested_mode: str | None,
saved_policy_action_mode: str | None = None,
) -> str:
"""Resolve the final inference action mode, validating compatibility."""
training_mode = _training_action_mode(config, saved_policy_action_mode)
if requested_mode is None:
requested_mode = config.inference_action_mode
if requested_mode is None:
raise ValueError(
"MolmoAct2 inference requires `inference_action_mode` to be set explicitly "
"to either 'continuous' or 'discrete'."
)
if requested_mode not in {"continuous", "discrete"}:
raise ValueError("MolmoAct2 inference_action_mode must be either 'continuous' or 'discrete'.")
if requested_mode == "continuous" and training_mode == "discrete":
raise ValueError("MolmoAct2 action_mode='discrete' checkpoint cannot run continuous inference.")
if requested_mode == "discrete" and training_mode == "continuous":
raise ValueError("MolmoAct2 action_mode='continuous' checkpoint cannot run discrete inference.")
return requested_mode
def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location: str) -> None:
index_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_INDEX_NAME)
single_file_path = os.path.join(checkpoint_location, SAFE_WEIGHTS_NAME)
@@ -103,16 +309,6 @@ def _strict_load_safetensors_weights(model: torch.nn.Module, checkpoint_location
)
def _torch_dtype(dtype: str) -> torch.dtype:
if dtype == "float32":
return torch.float32
if dtype == "bfloat16":
return torch.bfloat16
if dtype == "float16":
return torch.float16
raise ValueError(f"Unsupported dtype: {dtype}")
def _sample_beta_timesteps(
*,
batch_size: int,
@@ -136,7 +332,180 @@ def _sample_beta_timesteps(
return time_offset + scale * samples
def _mask_discrete_action_spans(
*,
input_ids: Tensor,
mask: Tensor,
start_token_id: int | None,
end_token_id: int | None,
) -> Tensor:
if start_token_id is None or end_token_id is None:
return mask
mask = mask.clone()
for batch_idx in range(input_ids.shape[0]):
row = input_ids[batch_idx]
starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist()
ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist()
end_ptr = 0
for start in starts:
while end_ptr < len(ends) and ends[end_ptr] < start:
end_ptr += 1
if end_ptr >= len(ends):
mask[batch_idx, start:] = False
break
end = int(ends[end_ptr])
mask[batch_idx, start : end + 1] = False
end_ptr += 1
return mask
def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]:
attention_mask = model_inputs.get("attention_mask")
if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()):
model_inputs = dict(model_inputs)
model_inputs.pop("attention_mask", None)
return model_inputs
def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None:
if mask is None:
return None
return (
mask.unsqueeze(1)
.expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1)))
.reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:])
)
def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None:
if action_dim_is_pad is None:
return None
mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool)
if mask.ndim == 1:
mask = mask.unsqueeze(0)
if mask.shape[-1] != target.shape[-1]:
raise ValueError(
f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}."
)
if mask.shape[0] == 1 and target.shape[0] != 1:
mask = mask.expand(target.shape[0], -1)
if mask.shape[0] != target.shape[0]:
raise ValueError(
f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}."
)
while mask.ndim < target.ndim:
mask = mask.unsqueeze(1)
return mask
def _mask_action_dim_tensor(tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
if action_dim_is_pad is None:
return tensor
valid_mask = _action_dim_valid_mask(tensor, action_dim_is_pad)
if valid_mask is None:
return tensor
return tensor.masked_fill(~valid_mask, 0)
def _apply_action_dim_padding_mask(loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
valid_mask = _action_dim_valid_mask(loss, action_dim_is_pad)
if valid_mask is None:
return loss
valid = valid_mask.to(dtype=loss.dtype)
denom = valid.sum(dim=-1).clamp_min(1.0)
return (loss * valid).sum(dim=-1) / denom
def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor:
if action_horizon_is_pad is None:
return loss
valid_action = (
(~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1)
)
return loss * valid_action
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
seed = 0
for idx in range(batch_size):
seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1)
return seed
def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None:
task = batch.get("task")
if task is None:
task = batch.get("observation.language")
if task is None:
return None
if isinstance(task, str):
return (task,)
if isinstance(task, (list, tuple)):
return tuple(str(item) for item in task)
return (str(task),)
def _extract_discrete_token_bins(
generated_ids: list[int],
start_token_id: int,
end_token_id: int,
token_id_to_bin: dict[int, int],
) -> list[int]:
start_idx = None
end_idx = None
for idx, token_id in enumerate(generated_ids):
if token_id == start_token_id:
start_idx = idx
break
if start_idx is not None:
for idx in range(start_idx + 1, len(generated_ids)):
if generated_ids[idx] == end_token_id:
end_idx = idx
break
span_start = 0 if start_idx is None else start_idx + 1
span_end = len(generated_ids) if end_idx is None else end_idx
return [
int(token_id_to_bin[token_id])
for token_id in generated_ids[span_start:span_end]
if token_id in token_id_to_bin
]
def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor:
if weights is None:
return values.mean()
weights = weights.to(device=values.device, dtype=values.dtype)
return torch.dot(values, weights) / weights.sum().clamp_min(1.0)
def _weighted_per_example(
values: Tensor,
weights: Tensor | None,
example_indices: Tensor,
batch_size: int,
) -> Tensor:
values = values.float()
if weights is None:
weights = torch.ones_like(values)
else:
weights = weights.to(device=values.device, dtype=values.dtype)
loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
loss_sum.scatter_add_(0, example_indices, values * weights)
weight_sum.scatter_add_(0, example_indices, weights)
global_weight_sum = weight_sum.sum().clamp_min(1.0)
return loss_sum * float(batch_size) / global_weight_sum
class MolmoAct2Policy(PreTrainedPolicy):
"""MolmoAct2 policy wrapping the vendored HF model for LeRobot.
Supports three training modes via ``config.action_mode``:
``"continuous"`` (flow-matching only), ``"discrete"`` (autoregressive
token prediction only), or ``"both"`` (joint loss). At inference,
``config.inference_action_mode`` selects which head generates actions.
"""
config_class = MolmoAct2Config
name = "molmoact2"
@@ -149,10 +518,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
**kwargs,
):
super().__init__(config, *inputs, **kwargs)
self.config.apply_norm_tag_metadata()
_apply_norm_tag_metadata(self.config)
self.config.validate_features()
del inputs, kwargs, dataset_stats, dataset_meta
self._checkpoint_action_mode = self.config.saved_policy_action_mode()
self._checkpoint_action_mode = _saved_policy_action_mode(self.config)
self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps)
self._rollout_action_generator: torch.Generator | None = None
self._rollout_task_key: tuple[Any, ...] | None = None
@@ -160,7 +529,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
self.rtc_processor: RTCProcessor | None = None
self.action_tokenizer: Any | None = None
self._load_hf_model()
self.config.validate_inference_action_mode(self._checkpoint_action_mode)
_validate_inference_action_mode(self.config, self._checkpoint_action_mode)
if self.config.enable_lora_vlm:
self._apply_lora_adapters()
self.init_rtc_processor()
@@ -212,7 +581,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
"`policy.checkpoint_force_download=true` after the updated files are pushed."
)
checkpoint_action_mode = str(self.model.config.action_mode)
self.config.validate_checkpoint_action_mode(
_validate_checkpoint_action_mode(
self.config,
checkpoint_action_mode,
has_action_expert=bool(getattr(self.model.config, "add_action_expert", False)),
)
@@ -226,6 +596,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
self.train(self.training)
def reset(self) -> None:
"""Clear the action queue and rollout generator between episodes."""
self._action_queue = deque(maxlen=self.config.n_action_steps)
self._rollout_action_generator = None
@@ -334,6 +705,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
param.requires_grad = False
def get_optim_params(self) -> list[dict[str, Any]]:
"""Return optimizer param groups with per-component learning rates."""
vit_params: list[Tensor] = []
connector_params: list[Tensor] = []
action_expert_params: list[Tensor] = []
@@ -419,33 +791,6 @@ class MolmoAct2Policy(PreTrainedPolicy):
return int(value)
raise RuntimeError("MolmoAct2 could not resolve an action generation horizon.")
@staticmethod
def _mask_discrete_action_spans(
*,
input_ids: Tensor,
mask: Tensor,
start_token_id: int | None,
end_token_id: int | None,
) -> Tensor:
if start_token_id is None or end_token_id is None:
return mask
mask = mask.clone()
for batch_idx in range(input_ids.shape[0]):
row = input_ids[batch_idx]
starts = (row == int(start_token_id)).nonzero(as_tuple=False).flatten().tolist()
ends = (row == int(end_token_id)).nonzero(as_tuple=False).flatten().tolist()
end_ptr = 0
for start in starts:
while end_ptr < len(ends) and ends[end_ptr] < start:
end_ptr += 1
if end_ptr >= len(ends):
mask[batch_idx, start:] = False
break
end = int(ends[end_ptr])
mask[batch_idx, start : end + 1] = False
end_ptr += 1
return mask
def _encoder_attention_mask_for_action_expert(
self,
*,
@@ -470,21 +815,13 @@ class MolmoAct2Policy(PreTrainedPolicy):
eos_token_id = getattr(self.model.config, "eos_token_id", None)
if eos_token_id is not None:
mask &= input_ids != int(eos_token_id)
return self._mask_discrete_action_spans(
return _mask_discrete_action_spans(
input_ids=input_ids,
mask=mask,
start_token_id=getattr(self.model.config, "action_start_token_id", None),
end_token_id=getattr(self.model.config, "action_end_token_id", None),
)
@staticmethod
def _drop_trivial_attention_mask(model_inputs: dict[str, Tensor]) -> dict[str, Tensor]:
attention_mask = model_inputs.get("attention_mask")
if torch.is_tensor(attention_mask) and bool(attention_mask.to(dtype=torch.bool).all().item()):
model_inputs = dict(model_inputs)
model_inputs.pop("attention_mask", None)
return model_inputs
def _load_discrete_action_tokenizer(self) -> Any:
if self.action_tokenizer is None:
require_package("transformers", extra="molmoact2")
@@ -498,27 +835,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
return self.action_tokenizer
def _resolve_inference_action_mode(self, requested_mode: str | None) -> str:
return self.config.resolve_inference_action_mode(requested_mode, self._checkpoint_action_mode)
@staticmethod
def _combine_rollout_seeds(first_seed: int, batch_size: int) -> int:
seed = 0
for idx in range(batch_size):
seed = (seed + (idx + 1) * (first_seed + idx)) % (2**63 - 1)
return seed
@staticmethod
def _rollout_task_signature(batch: dict[str, Any]) -> tuple[Any, ...] | None:
task = batch.get("task")
if task is None:
task = batch.get("observation.language")
if task is None:
return None
if isinstance(task, str):
return (task,)
if isinstance(task, (list, tuple)):
return tuple(str(item) for item in task)
return (str(task),)
return _resolve_inference_action_mode(self.config, requested_mode, self._checkpoint_action_mode)
def _rollout_generator_for_inputs(
self,
@@ -532,7 +849,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
if self._rollout_action_generator is not None:
return self._rollout_action_generator
task_signature = self._rollout_task_signature(batch)
task_signature = _rollout_task_signature(batch)
if task_signature != self._rollout_task_key:
self._rollout_task_key = task_signature
self._rollout_index_for_task = 0
@@ -545,72 +862,10 @@ class MolmoAct2Policy(PreTrainedPolicy):
device if device.type == "cuda" and torch.cuda.is_available() else torch.device("cpu")
)
generator = torch.Generator(device=generator_device)
generator.manual_seed(self._combine_rollout_seeds(first_seed, batch_size))
generator.manual_seed(_combine_rollout_seeds(first_seed, batch_size))
self._rollout_action_generator = generator
return generator
@staticmethod
def _expand_mask(mask: Tensor | None, num_flow_timesteps: int) -> Tensor | None:
if mask is None:
return None
return (
mask.unsqueeze(1)
.expand(-1, num_flow_timesteps, *([-1] * (mask.ndim - 1)))
.reshape(mask.shape[0] * num_flow_timesteps, *mask.shape[1:])
)
@staticmethod
def _action_dim_valid_mask(target: Tensor, action_dim_is_pad: Tensor | None) -> Tensor | None:
if action_dim_is_pad is None:
return None
mask = ~action_dim_is_pad.to(device=target.device, dtype=torch.bool)
if mask.ndim == 1:
mask = mask.unsqueeze(0)
if mask.shape[-1] != target.shape[-1]:
raise ValueError(
f"action_dim_is_pad width {mask.shape[-1]} does not match target width {target.shape[-1]}."
)
if mask.shape[0] == 1 and target.shape[0] != 1:
mask = mask.expand(target.shape[0], -1)
if mask.shape[0] != target.shape[0]:
raise ValueError(
f"action_dim_is_pad batch {mask.shape[0]} does not match target batch {target.shape[0]}."
)
while mask.ndim < target.ndim:
mask = mask.unsqueeze(1)
return mask
@classmethod
def _mask_action_dim_tensor(cls, tensor: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
if not cls._mask_enabled_static(action_dim_is_pad):
return tensor
valid_mask = cls._action_dim_valid_mask(tensor, action_dim_is_pad)
if valid_mask is None:
return tensor
return tensor.masked_fill(~valid_mask, 0)
@staticmethod
def _mask_enabled_static(action_dim_is_pad: Tensor | None) -> bool:
return action_dim_is_pad is not None
@classmethod
def _apply_action_dim_padding_mask(cls, loss: Tensor, action_dim_is_pad: Tensor | None) -> Tensor:
valid_mask = cls._action_dim_valid_mask(loss, action_dim_is_pad)
if valid_mask is None:
return loss
valid = valid_mask.to(dtype=loss.dtype)
denom = valid.sum(dim=-1).clamp_min(1.0)
return (loss * valid).sum(dim=-1) / denom
@staticmethod
def _apply_action_chunk_padding_mask(loss: Tensor, action_horizon_is_pad: Tensor | None) -> Tensor:
if action_horizon_is_pad is None:
return loss
valid_action = (
(~action_horizon_is_pad.to(device=loss.device, dtype=torch.bool)).unsqueeze(1).unsqueeze(-1)
)
return loss * valid_action
def _prepare_flow_matching_tensors(
self,
*,
@@ -649,7 +904,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
)
if self.config.mask_action_dim_padding:
actions = self._mask_action_dim_tensor(actions, action_dim_is_pad)
actions = _mask_action_dim_tensor(actions, action_dim_is_pad)
expected_noise_shape = (batch_size, num_flow_timesteps, actions.shape[1], actions.shape[2])
if noise is None:
@@ -661,7 +916,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
f"flow noise must have shape {expected_noise_shape}, got {tuple(noise.shape)}."
)
if self.config.mask_action_dim_padding:
noise = self._mask_action_dim_tensor(noise, action_dim_is_pad)
noise = _mask_action_dim_tensor(noise, action_dim_is_pad)
t_broadcast = timesteps.view(batch_size, num_flow_timesteps, 1, 1)
actions_expanded = actions.unsqueeze(1).expand(-1, num_flow_timesteps, -1, -1)
@@ -789,7 +1044,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
valid_action = None
if action_attention_mask is not None:
valid_action = action_attention_mask.to(device=device, dtype=actions.dtype).unsqueeze(-1)
valid_action = self._expand_mask(valid_action, num_flow_timesteps)
valid_action = _expand_mask(valid_action, num_flow_timesteps)
rope_cache = None
if len(action_expert.blocks) > 0 and action_expert.blocks[0].self_attn.rope is not None:
@@ -804,14 +1059,14 @@ class MolmoAct2Policy(PreTrainedPolicy):
batch_size,
actions.dtype,
)
cross_mask = self._expand_mask(cross_mask, num_flow_timesteps)
cross_mask = _expand_mask(cross_mask, num_flow_timesteps)
self_mask = action_expert._build_self_attention_mask(
action_attention_mask,
actions.shape[1],
device,
actions.dtype,
)
self_mask = self._expand_mask(self_mask, num_flow_timesteps)
self_mask = _expand_mask(self_mask, num_flow_timesteps)
conditioning = self._action_time_conditioning(action_expert, timesteps_flat)
action_hidden = action_expert.action_embed(xt_flat)
@@ -871,8 +1126,8 @@ class MolmoAct2Policy(PreTrainedPolicy):
if k_norm is not None:
k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2)
if num_flow_timesteps != 1:
k_ctx = self._expand_mask(k_ctx, num_flow_timesteps)
v_ctx = self._expand_mask(v_ctx, num_flow_timesteps)
k_ctx = _expand_mask(k_ctx, num_flow_timesteps)
v_ctx = _expand_mask(v_ctx, num_flow_timesteps)
next_action_hidden = action_block(
layer_action_hidden,
@@ -912,9 +1167,9 @@ class MolmoAct2Policy(PreTrainedPolicy):
)
loss = F.mse_loss(pred_velocity, target_velocity, reduction="none")
loss = self._apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad"))
loss = _apply_action_chunk_padding_mask(loss, batch.get("action_horizon_is_pad"))
if self.config.mask_action_dim_padding:
loss = self._apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad"))
loss = _apply_action_dim_padding_mask(loss, batch.get("action_dim_is_pad"))
loss = loss.reshape(batch_size, -1).mean(dim=1)
if reduction == "mean":
loss = loss.mean()
@@ -933,32 +1188,6 @@ class MolmoAct2Policy(PreTrainedPolicy):
example_weights[nonempty] = 2.0 / torch.sqrt(token_counts[nonempty])
return example_weights[:, None].expand_as(valid_positions)[valid_positions].to(dtype=torch.float32)
@staticmethod
def _weighted_mean(values: Tensor, weights: Tensor | None) -> Tensor:
if weights is None:
return values.mean()
weights = weights.to(device=values.device, dtype=values.dtype)
return torch.dot(values, weights) / weights.sum().clamp_min(1.0)
@staticmethod
def _weighted_per_example(
values: Tensor,
weights: Tensor | None,
example_indices: Tensor,
batch_size: int,
) -> Tensor:
values = values.float()
if weights is None:
weights = torch.ones_like(values)
else:
weights = weights.to(device=values.device, dtype=values.dtype)
loss_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
weight_sum = torch.zeros(batch_size, device=values.device, dtype=torch.float32)
loss_sum.scatter_add_(0, example_indices, values * weights)
weight_sum.scatter_add_(0, example_indices, weights)
global_weight_sum = weight_sum.sum().clamp_min(1.0)
return loss_sum * float(batch_size) / global_weight_sum
def _discrete_loss_from_backbone_outputs(
self,
batch: dict[str, Tensor],
@@ -992,56 +1221,28 @@ class MolmoAct2Policy(PreTrainedPolicy):
token_weights = self._discrete_token_weights(valid_positions)
if reduction == "none":
example_indices = valid_positions.nonzero(as_tuple=False)[:, 0].to(device=hidden_states.device)
ce_loss = self._weighted_per_example(
ce_loss = _weighted_per_example(
token_ce_loss,
token_weights,
example_indices,
int(labels.shape[0]),
)
else:
ce_loss = self._weighted_mean(token_ce_loss, token_weights)
ce_loss = _weighted_mean(token_ce_loss, token_weights)
if not self.config.softmax_auxiliary_loss:
return ce_loss, None
if reduction == "none":
z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_per_example(
z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_per_example(
log_z.pow(2),
token_weights,
example_indices,
int(labels.shape[0]),
)
else:
z_loss = self.config.softmax_auxiliary_loss_scale * self._weighted_mean(
log_z.pow(2), token_weights
)
z_loss = self.config.softmax_auxiliary_loss_scale * _weighted_mean(log_z.pow(2), token_weights)
return ce_loss, z_loss
@staticmethod
def _extract_discrete_token_bins(
generated_ids: list[int],
start_token_id: int,
end_token_id: int,
token_id_to_bin: dict[int, int],
) -> list[int]:
start_idx = None
end_idx = None
for idx, token_id in enumerate(generated_ids):
if token_id == start_token_id:
start_idx = idx
break
if start_idx is not None:
for idx in range(start_idx + 1, len(generated_ids)):
if generated_ids[idx] == end_token_id:
end_idx = idx
break
span_start = 0 if start_idx is None else start_idx + 1
span_end = len(generated_ids) if end_idx is None else end_idx
return [
int(token_id_to_bin[token_id])
for token_id in generated_ids[span_start:span_end]
if token_id in token_id_to_bin
]
def _action_token_id_to_bin(self) -> dict[int, int]:
method = getattr(self.model, "_action_token_id_to_bin", None)
if callable(method):
@@ -1179,7 +1380,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
chunks: list[Tensor] = []
for token_row in generated_token_ids:
generated_ids = [int(token_id) for token_id in token_row.detach().cpu().tolist()]
discrete_token_ids = self._extract_discrete_token_bins(
discrete_token_ids = _extract_discrete_token_bins(
generated_ids,
int(self.model.config.action_start_token_id),
int(self.model.config.action_end_token_id),
@@ -1218,7 +1419,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
model_inputs: dict[str, Tensor],
action_dim: int,
) -> Tensor:
model_inputs = self._drop_trivial_attention_mask(model_inputs)
model_inputs = _drop_trivial_attention_mask(model_inputs)
max_steps = self._discrete_generation_max_steps()
static_cache, attention_bias = self._make_discrete_ar_graph_decode_inputs(
model_inputs,
@@ -1294,7 +1495,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
generator=generator,
)
if self.config.mask_action_dim_padding:
trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad)
trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad)
action_context = action_expert.prepare_context(
encoder_kv_states=encoder_kv_states,
@@ -1327,7 +1528,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
modulation=step_modulation,
)
if mask_enabled:
velocity = self._mask_action_dim_tensor(velocity, action_dim_is_pad)
velocity = _mask_action_dim_tensor(velocity, action_dim_is_pad)
return velocity
if self._rtc_enabled():
@@ -1352,7 +1553,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
trajectory = trajectory + dt * velocity
if mask_enabled:
trajectory = self._mask_action_dim_tensor(trajectory, action_dim_is_pad)
trajectory = _mask_action_dim_tensor(trajectory, action_dim_is_pad)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=float(flow_timestep[0].item()), x_t=trajectory, v_t=velocity)
@@ -1363,6 +1564,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
batch: dict[str, Tensor],
reduction: str = "mean",
) -> tuple[Tensor, dict[str, Any]]:
"""Compute training loss (flow-matching and/or discrete token loss)."""
if reduction not in {"mean", "none"}:
raise ValueError(f"Unsupported reduction={reduction!r}. Expected 'mean' or 'none'.")
model_inputs = self._model_inputs(batch)
@@ -1422,6 +1624,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Generate an action chunk via continuous flow matching or discrete AR decoding."""
if "action_mode" in kwargs:
raise TypeError(
"MolmoAct2 predict_action_chunk got unexpected keyword argument 'action_mode'; "
@@ -1476,6 +1679,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
"""Pop one action step from the queue, regenerating the chunk when empty."""
if self._rtc_enabled():
raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk")
self.eval()
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,23 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
import logging
import os
from pathlib import Path
from typing import ClassVar
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
from ..modeling_molmoact2 import _hf_token
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
logger = logging.getLogger(__name__)
def _resolve_tokenizer_location(
@@ -42,6 +36,8 @@ def _resolve_tokenizer_location(
local_path = Path(str(tokenizer_path)).expanduser()
if local_path.exists():
return str(local_path)
from huggingface_hub import snapshot_download
return snapshot_download(
repo_id=str(tokenizer_path),
repo_type="model",
@@ -134,9 +130,8 @@ class UniversalActionProcessor(ProcessorMixin):
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
)
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
except Exception:
logger.warning("Error decoding tokens: %s", token, exc_info=True)
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
MolmoAct2 configuration
"""
from typing import Optional, Any
from typing import Any
from transformers import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,33 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Image processor class for MolmoAct2"""
from typing import Optional, Union
import numpy as np
import einops
import numpy as np
import torch
import torchvision.transforms
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.image_transforms import convert_to_rgb
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ImageInput,
PILImageResampling,
make_flat_list_of_images,
valid_images,
to_numpy_array,
valid_images,
)
from transformers.image_transforms import convert_to_rgb
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType, logging
logger = logging.get_logger(__name__)
@@ -73,8 +66,8 @@ def resize_image(
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
assert image.dtype == torch.uint8, (
f"SigLIP expects float images or uint8 images, but got {image.dtype}"
)
in_min = 0.0
in_max = 255.0
@@ -96,7 +89,6 @@ def resize_image(
def select_tiling(h, w, patch_size, max_num_crops):
"""Divide in image of size [w, h] in up to max_num_patches of size patch_size"""
original_size = np.stack([h, w]) # [1, 2]
original_res = h * w
tilings = []
for i in range(1, max_num_crops + 1):
for j in range(1, max_num_crops + 1):
@@ -406,13 +398,17 @@ class MolmoAct2ImageProcessor(BaseImageProcessor):
image_std: float | list[float] | None = None,
do_convert_rgb: bool = True,
max_crops: int = 8,
overlap_margins: list[int] = [4, 4],
overlap_margins: list[int] | None = None,
crop_mode: str = "overlap-and-resize-c2",
patch_size: int = 14,
pooling_size: list[int] = [2, 2],
pooling_size: list[int] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if overlap_margins is None:
overlap_margins = [4, 4]
if pooling_size is None:
pooling_size = [2, 2]
size = size if size is not None else {"height": 378, "width": 378}
size = get_size_dict(size, default_to_square=True)
self.size = size
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,16 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Inference utilities for MolmoAct2"""
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any
import torch
from torch.nn import functional as F
from torch.nn import functional as F # noqa: N812
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
@@ -679,7 +676,7 @@ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
def _copy_context_(dst: Any, src: Any) -> None:
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts, strict=False):
dst_k.copy_(src_k)
dst_v.copy_(src_v)
if src.cross_mask is not None:
@@ -689,7 +686,7 @@ def _copy_context_(dst: Any, src: Any) -> None:
if src.valid_action is not None:
dst.valid_action.copy_(src.valid_action)
if src.rope_cache is not None:
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache, strict=False):
dst_tensor.copy_(src_tensor)
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,24 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Modeling code for MolmoAct2"""
# ruff: noqa: N806
import json
import math
import os
import re
from collections.abc import Callable, Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Optional
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn import functional as F # noqa: N812
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
@@ -647,7 +646,7 @@ class ActionExpert(nn.Module):
f"got {len(encoder_kv_states)}."
)
kv_contexts = []
for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states):
for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states, strict=False):
k_ctx = self._project_kv_tensor(k_in, self.context_k_proj)
v_ctx = self._project_kv_tensor(v_in, self.context_v_proj)
k_norm = block.cross_attn.k_norm
@@ -732,7 +731,7 @@ class ActionExpert(nn.Module):
timesteps: Sequence[torch.Tensor],
) -> Sequence[ActionExpertStepModulation]:
cache = []
for idx, step_t in enumerate(timesteps):
for _idx, step_t in enumerate(timesteps):
conditioning = self._time_conditioning(step_t)
block_modulations = []
for block in self.blocks:
@@ -786,8 +785,8 @@ class ActionExpert(nn.Module):
x = self.action_embed(actions)
if context.valid_action is not None:
x = x * context.valid_action
for idx, (block, kv_context, block_modulation) in enumerate(
zip(self.blocks, context.kv_contexts, block_modulations)
for _idx, (block, kv_context, block_modulation) in enumerate(
zip(self.blocks, context.kv_contexts, block_modulations, strict=False)
):
x = block(
x,
@@ -2874,7 +2873,7 @@ class MolmoAct2Model(MolmoAct2PreTrainedModel):
depth_mask=depth_mask,
encoder_attention_mask=encoder_attention_mask,
)
for gate, source in zip(gate_head, sources)
for gate, source in zip(gate_head, sources, strict=False)
]
return gates, depth_mask
gate = self._depth_gate_from_source(
@@ -4458,7 +4457,7 @@ class MolmoAct2ForConditionalGeneration(MolmoAct2PreTrainedModel, GenerationMixi
```python
>>> from PIL import Image
>>> import requests
>>> from lerobot.policies.molmoact2.hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
>>> from lerobot.policies.molmoact2.molmoact2_hf_model.modeling_molmoact2 import MolmoAct2ForConditionalGeneration
>>> from lerobot.policies.molmoact2.processor_molmoact2 import _load_local_molmoact2_processor
>>> model = MolmoAct2ForConditionalGeneration.from_pretrained("...")
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,45 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""
Processor class for MolmoAct2.
"""
from typing import Optional, Union
import dataclasses
import numpy as np
from transformers import AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import (
Unpack,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
from transformers import AutoTokenizer
from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
from .image_processing_molmoact2 import MolmoAct2ImageProcessor, MolmoAct2ImagesKwargs
from .video_processing_molmoact2 import MolmoAct2VideoProcessor, MolmoAct2VideoProcessorKwargs
logger = logging.get_logger(__name__)
# Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
IM_START_TOKEN = f"<im_start>"
LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
FRAME_START_TOKEN = f"<frame_start>"
IM_END_TOKEN = f"<im_end>"
FRAME_END_TOKEN = f"<frame_end>"
IM_COL_TOKEN = f"<im_col>"
IMAGE_PATCH_TOKEN = "<im_patch>" # nosec B105 # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = "<im_low>" # nosec B105 # Where to insert low-res tokens
IM_START_TOKEN = "<im_start>" # nosec B105
LOW_RES_IMAGE_START_TOKEN = "<low_res_im_start>" # nosec B105
FRAME_START_TOKEN = "<frame_start>" # nosec B105
IM_END_TOKEN = "<im_end>" # nosec B105
FRAME_END_TOKEN = "<frame_end>" # nosec B105
IM_COL_TOKEN = "<im_col>" # nosec B105
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
@@ -224,7 +216,7 @@ class MolmoAct2Processor(ProcessorMixin):
input_ids = input_ids[None, :]
attention_mask = attention_mask[None, :]
B, S = input_ids.shape
B, S = input_ids.shape # noqa: N806
# Handle zero-length sequence
if S == 0:
@@ -364,7 +356,7 @@ class MolmoAct2Processor(ProcessorMixin):
assert num_videos in {0, 1}, "At most one video is supported for now"
video_grids_i = video_grids[index : index + num_videos]
metadata_i = video_metadata[index : index + num_videos]
for video_grid, metadata in zip(video_grids_i, metadata_i):
for video_grid, metadata in zip(video_grids_i, metadata_i, strict=False):
video_string = self.get_video_string(
video_grid,
metadata.timestamps,
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,25 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa
"""Video processor class for MolmoAct2"""
from functools import partial
import os
import warnings
from collections.abc import Callable
from contextlib import redirect_stdout
from functools import partial
from io import BytesIO
from urllib.parse import urlparse
from typing import Optional, Union
from collections.abc import Callable
import einops
import numpy as np
import requests
import einops
import torch
import torchvision.transforms
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
@@ -41,27 +37,24 @@ from transformers.image_utils import (
SizeDict,
validate_kwargs,
)
from transformers.video_utils import (
VideoInput,
is_valid_video,
make_batched_videos,
make_batched_metadata,
VideoMetadata,
)
from transformers.processing_utils import Unpack, VideosKwargs
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import (
TensorType,
is_av_available,
is_decord_available,
is_torchcodec_available,
is_yt_dlp_available,
TensorType,
logging,
to_numpy,
)
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.video_utils import (
VideoInput,
VideoMetadata,
is_valid_video,
make_batched_metadata,
make_batched_videos,
)
logger = logging.get_logger(__name__)
@@ -102,8 +95,8 @@ def resize_image(
)(image)
resized = torch.clip(resized, 0.0, 1.0).to(dtype)
else:
assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(
image.dtype
assert image.dtype == torch.uint8, (
f"SigLIP expects float images or uint8 images, but got {image.dtype}"
)
in_min = 0.0
in_max = 255.0
@@ -548,9 +541,8 @@ def get_target_fps(
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if "uniform" in frame_sample_mode:
if num_frames_sampled_at_fps > max_frames:
break
if "uniform" in frame_sample_mode and num_frames_sampled_at_fps > max_frames:
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
@@ -779,13 +771,15 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
elif is_torchcodec_available():
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `torchcodec`."
"Falling back to `torchcodec`.",
stacklevel=2,
)
backend = "torchcodec"
else:
warnings.warn(
"`decord` is not installed and cannot be used to decode the video by default. "
"Falling back to `PyAV`."
"Falling back to `PyAV`.",
stacklevel=2,
)
backend = "pyav"
@@ -795,7 +789,8 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
*[
self.fetch_videos(x, sample_timestamps_fn=sample_timestamps_fn)
for x in video_url_or_urls
]
],
strict=False,
)
)
else:
@@ -821,7 +816,7 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
assert video_metadata[0].fps is not None, "FPS must be provided for video input"
sampled_videos = []
sampled_metadata = []
for video, metadata in zip(videos, video_metadata):
for video, metadata in zip(videos, video_metadata, strict=False):
indices = sample_indices_fn(metadata=metadata)
metadata.frames_indices = indices
sampled_videos.append(video[indices])
@@ -985,11 +980,11 @@ class MolmoAct2VideoProcessor(BaseVideoProcessor):
pixel_values_videos = np.concatenate(batch_crops, 0)
video_token_pooling = np.concatenate(batch_pooled_patches_idx, 0)
data = dict(
pixel_values_videos=pixel_values_videos,
video_token_pooling=video_token_pooling,
video_grids=video_grids,
)
data = {
"pixel_values_videos": pixel_values_videos,
"video_token_pooling": video_token_pooling,
"video_grids": video_grids,
}
return BatchFeature(data, tensor_type=return_tensors)
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""MolmoAct2 pre/post processing pipeline.
Builds the multimodal prompt (images, discretised state, task text),
tokenises it via the vendored MolmoAct2 processor, and handles quantile
normalisation with optional per-dimension gripper masking.
"""
from __future__ import annotations
import json
import os
import logging
import math
import re
from contextlib import suppress
from copy import deepcopy
@@ -27,7 +33,6 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from huggingface_hub import snapshot_download
from torch import Tensor
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
@@ -54,14 +59,71 @@ from lerobot.utils.constants import (
)
from lerobot.utils.import_utils import _scipy_available, _transformers_available, require_package
from .configuration_molmoact2 import MolmoAct2Config, infer_molmoact2_max_sequence_length
from .configuration_molmoact2 import MolmoAct2Config
from .modeling_molmoact2 import _hf_token, _resolve_checkpoint_location
logger = logging.getLogger(__name__)
MOLMOACT2_DEFAULT_NUM_IMAGES = 2
MOLMOACT2_IMAGE_TOKENS_PER_IMAGE = 196
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET = 80
MOLMOACT2_TASK_TOKEN_BUDGET = 32
MOLMOACT2_SEQUENCE_LENGTH_MARGIN = 32
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE = 64
MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS = 4
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP = 6
MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM = 0.95
def _round_up(value: int, multiple: int) -> int:
return int(math.ceil(value / multiple) * multiple)
def infer_molmoact2_max_sequence_length(
*,
num_images: int,
state_dim: int,
action_dim: int,
action_horizon: int,
include_discrete_action: bool,
) -> int:
"""Infer the padded text/image sequence cap from MolmoAct2's fixed token layout."""
if num_images < 1:
num_images = MOLMOACT2_DEFAULT_NUM_IMAGES
if state_dim < 0:
state_dim = 0
if action_dim < 1:
action_dim = 1
if action_horizon < 1:
action_horizon = 1
image_tokens = num_images * MOLMOACT2_IMAGE_TOKENS_PER_IMAGE
prompt_tokens = (
MOLMOACT2_FIXED_PROMPT_TOKEN_BUDGET
+ MOLMOACT2_TASK_TOKEN_BUDGET
+ state_dim
+ MOLMOACT2_SEQUENCE_LENGTH_MARGIN
)
action_tokens = 0
if include_discrete_action:
action_tokens_per_step = max(
MOLMOACT2_MIN_DISCRETE_ACTION_TOKENS_PER_STEP,
math.ceil(action_dim * MOLMOACT2_DISCRETE_ACTION_TOKENS_PER_DIM),
)
action_tokens = MOLMOACT2_DISCRETE_ACTION_WRAPPER_TOKENS + action_horizon * action_tokens_per_step
return _round_up(
image_tokens + prompt_tokens + action_tokens,
MOLMOACT2_SEQUENCE_LENGTH_MULTIPLE,
)
if TYPE_CHECKING or _transformers_available:
from transformers import Qwen2Tokenizer
from .hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor
from .hf_model.processing_molmoact2 import MolmoAct2Processor
from .hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor
from .molmoact2_hf_model.image_processing_molmoact2 import MolmoAct2ImageProcessor
from .molmoact2_hf_model.processing_molmoact2 import MolmoAct2Processor
from .molmoact2_hf_model.video_processing_molmoact2 import MolmoAct2VideoProcessor
else:
Qwen2Tokenizer = None
MolmoAct2ImageProcessor = None
@@ -69,7 +131,7 @@ else:
MolmoAct2VideoProcessor = None
if TYPE_CHECKING or (_transformers_available and _scipy_available):
from .hf_model.action_tokenizer import UniversalActionProcessor
from .molmoact2_hf_model.action_tokenizer import UniversalActionProcessor
else:
UniversalActionProcessor = None
@@ -97,32 +159,6 @@ _QUESTION_PREFIX_PATTERNS = tuple(
)
def _hf_token() -> str | None:
return os.environ.get("HF_TOKEN") or os.environ.get("HF_ACCESS_TOKEN")
def _resolve_checkpoint_location(
checkpoint_path: str,
*,
revision: str | None = None,
force_download: bool = False,
) -> str:
checkpoint_path = str(checkpoint_path or "").strip()
if not checkpoint_path:
raise ValueError("MolmoAct2 policy requires `checkpoint_path`.")
local_path = Path(checkpoint_path).expanduser()
if local_path.exists():
return str(local_path)
return snapshot_download(
repo_id=checkpoint_path,
repo_type="model",
revision=revision,
force_download=force_download,
ignore_patterns=["*.py", "*.pyc", "__pycache__/*"],
token=_hf_token(),
)
def _load_hf_norm_stats_for_tag(
checkpoint_path: str,
*,
@@ -1,2 +0,0 @@
# Local-only parity artifacts (regenerated via dump_original_n1_7.py); never committed.
*.npz
+11 -23
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script for LeRobot's GR00T N1.7 policy forward and inference passes."""
"""Test script for LeRobot's Groot policy forward and inference passes."""
import gc
import os
@@ -41,20 +41,13 @@ pytestmark = pytest.mark.skipif(
)
# Define constants for dummy data (GR00T N1.7 native conventions).
# N1.7 internally uses a 40-step action chunk, 132-dim state/action, and 256px images
# (see GrootConfig.__post_init__). Use a chunk-sized action horizon so the dummy batch
# matches the model's native action space.
# Define constants for dummy data
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 40
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = auto_select_torch_device()
# GR00T N1.7 checkpoint (N1.5 is no longer supported). The N1.7-3B base model loads
# via GrootPolicy.from_pretrained with root-level sharded safetensors.
MODEL_PATH = "nvidia/GR00T-N1.7-3B"
# Valid N1.7 embodiment tag carried by the checkpoint metadata.
EMBODIMENT_TAG = "gr1_unified"
MODEL_PATH = "aractingi/bimanual-handover-groot-10k"
def cleanup_memory():
@@ -95,13 +88,13 @@ def instantiate_lerobot_groot(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot GR00T N1.7 policy with preprocessor and postprocessor."""
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
if from_pretrained:
policy = GrootPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
policy.config.embodiment_tag = EMBODIMENT_TAG
policy.config.embodiment_tag = "gr1"
else:
config = GrootConfig(
base_model_path=model_path,
@@ -109,7 +102,7 @@ def instantiate_lerobot_groot(
chunk_size=DUMMY_ACTION_HORIZON,
image_size=[IMAGE_SIZE, IMAGE_SIZE],
device=DEVICE,
embodiment_tag=EMBODIMENT_TAG,
embodiment_tag="gr1",
)
policy = GrootPolicy(config)
@@ -155,8 +148,8 @@ def create_dummy_data(device=DEVICE):
@require_cuda
def test_lerobot_groot_inference():
"""Test the inference pass (select_action) of LeRobot's GR00T N1.7 policy."""
print("Test: LeRobot GR00T N1.7 Inference Pass")
"""Test the inference pass (select_action) of LeRobot's Groot policy."""
print("Test: LeRobot Groot Inference Pass")
set_seed_all(42)
@@ -188,9 +181,9 @@ def test_lerobot_groot_inference():
@require_cuda
def test_lerobot_groot_forward_pass():
"""Test the forward pass of LeRobot's GR00T N1.7 policy."""
"""Test the forward pass of LeRobot's Groot policy."""
print("\n" + "=" * 50)
print("Test: LeRobot GR00T N1.7 Forward Pass (Training Mode)")
print("Test: LeRobot Groot Forward Pass (Training Mode)")
set_seed_all(42)
@@ -207,11 +200,6 @@ def test_lerobot_groot_forward_pass():
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
assert isinstance(lerobot_loss, torch.Tensor)
assert torch.isfinite(lerobot_loss).all()
assert "loss" in lerobot_metrics
assert np.isfinite(lerobot_metrics["loss"])
print("\nForward pass successful.")
print(f" - Loss: {lerobot_loss.item():.6f}")
print(f" - Metrics: {lerobot_metrics}")
File diff suppressed because it is too large Load Diff
+408 -171
View File
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,194 +14,431 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
normalized flow-matching prediction before any action decoding) as NVIDIA's original
``gr00t`` package, given byte-identical pre-processed inputs and the same
flow-matching seed. The comparison is parametrized over every embodiment tag present
in the checkpoint.
To keep the comparison fair, the original outputs + the exact collated inputs are
produced once per embodiment in the original ``gr00t`` env via the companion script
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
to per-tag ``.npz`` files.
This test discovers those artifacts, replays the identical inputs through the LeRobot
model, and compares.
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
present, or when no artifact has been generated. By default it looks for artifacts in
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
for the full run procedure.
"""
"""Test script to verify Groot policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import gc
import os
from pathlib import Path
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
pytest.importorskip("gr00t")
pytest.importorskip("transformers")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="Requires a local GR00T N1.7 checkpoint + pre-generated artifacts; not for CI.",
reason="This test requires local Groot installation and is not meant for CI",
)
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
SEED = 42
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
from gr00t.data.dataset import ModalityConfig # noqa: E402
from gr00t.data.embodiment_tags import EmbodimentTag # noqa: E402
from gr00t.data.transform.base import ComposedModalityTransform # noqa: E402
from gr00t.model.policy import Gr00tPolicy # noqa: E402
# Artifact filenames are original_n1_7_<embodiment_tag>.npz
_ARTIFACT_PREFIX = "original_n1_7_"
_ARTIFACT_SUFFIX = ".npz"
# GR1 humanoid dimensions (from pretrained model metadata)
# The actual GR1 robot has 44 dimensions for both state and action
# GR00TTransform will pad state to 64 and truncate action to 32
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = "cpu"
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
GR1_BODY_PARTS = {
"left_arm": 7,
"left_hand": 6,
"left_leg": 6,
"neck": 3,
"right_arm": 7,
"right_hand": 6,
"right_leg": 6,
"waist": 3,
}
def _artifact_dir() -> Path:
"""Directory holding the per-embodiment .npz artifacts.
Self-contained by default: a sibling ``artifacts/`` directory next to this test.
Override with ``GROOT_N1_7_PARITY_DIR`` (e.g. to point at a scratch location).
The directory is read-only here -- it is populated by ``utils/dump_original_n1_7.py``
run in the original gr00t environment; the test never creates it.
"""
env = os.environ.get("GROOT_N1_7_PARITY_DIR")
if env:
return Path(env)
return Path(__file__).resolve().parent / "artifacts"
def _discover_artifacts() -> list[tuple[str, Path]]:
"""Return [(embodiment_tag, npz_path), ...] for every dumped artifact."""
d = _artifact_dir()
if not d.is_dir():
return []
out = []
for p in sorted(d.glob(f"{_ARTIFACT_PREFIX}*{_ARTIFACT_SUFFIX}")):
tag = p.name[len(_ARTIFACT_PREFIX) : -len(_ARTIFACT_SUFFIX)]
out.append((tag, p))
return out
def _resolve_checkpoint() -> str:
env = os.environ.get("GROOT_N1_7_LIBERO_CKPT")
if env:
if not Path(env).exists():
pytest.skip(f"GROOT_N1_7_LIBERO_CKPT={env} does not exist")
return env
try:
from huggingface_hub import snapshot_download
root = snapshot_download(
"nvidia/GR00T-N1.7-LIBERO",
local_files_only=True,
allow_patterns=["libero_10/*"],
)
except Exception as exc: # noqa: BLE001
pytest.skip(f"GR00T N1.7 LIBERO checkpoint not available locally: {exc}")
ckpt = Path(root) / "libero_10"
if not (ckpt / "config.json").exists():
pytest.skip(f"GR00T N1.7 LIBERO checkpoint incomplete at {ckpt}")
return str(ckpt)
def _load_artifact(path: Path):
data = np.load(path, allow_pickle=True)
original_action = torch.from_numpy(data["action_pred"]).float()
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
inputs = {}
for key in data.files:
if not key.startswith("in::"):
continue
name = key[4:]
arr = data[key]
t = torch.from_numpy(np.asarray(arr))
declared = dtypes.get(key, "")
if "int" in declared or "long" in declared:
t = t.long()
inputs[name] = t
return original_action, inputs
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
"""Rebuild the nested model-input dict from dot-prefixed flat keys."""
nested: dict = {}
for dotted, value in inputs.items():
parts = dotted.split(".")
cur = nested
for p in parts[:-1]:
cur = cur.setdefault(p, {})
cur[parts[-1]] = value
return nested.get("inputs", nested)
@pytest.fixture(scope="module")
def lerobot_model():
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
ckpt = _resolve_checkpoint()
from lerobot.policies.groot.groot_n1_7 import GR00TN17
model = GR00TN17.from_pretrained(
ckpt,
tune_llm=False,
tune_visual=False,
tune_projector=False,
tune_diffusion_model=False,
tune_vlln=False,
transformers_loading_kwargs={"trust_remote_code": True},
)
# fp32 + SDPA on both sides: bf16 + differing attention kernels otherwise introduce
# ~1e-2 numerical noise unrelated to the implementations.
model.compute_dtype = "float32"
model.config.compute_dtype = model.compute_dtype
model.to(device=DEVICE, dtype=torch.float32)
model.eval()
return model
_ARTIFACTS = _discover_artifacts()
@pytest.mark.skipif(
not _ARTIFACTS,
reason=(
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
"env:\n .venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py "
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
),
)
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
original_action, flat_inputs = _load_artifact(artifact)
model_inputs = _unflatten(flat_inputs)
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
torch.manual_seed(SEED)
def cleanup_memory():
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
print("\nCleaning up memory...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
with torch.inference_mode():
out = lerobot_model.get_action(model_inputs)
lerobot_action = out["action_pred"].float().cpu()
torch.cuda.empty_cache()
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
print("Memory cleanup complete.")
t = min(original_action.shape[1], lerobot_action.shape[1])
d = min(original_action.shape[2], lerobot_action.shape[2])
original_action = original_action[:, :t, :d]
lerobot_action = lerobot_action[:, :t, :d]
diff = torch.abs(lerobot_action - original_action)
max_diff = diff.max().item()
print(
f"\n[{embodiment_tag}] shapes lerobot={tuple(lerobot_action.shape)} "
f"original={tuple(original_action.shape)} "
f"max|diff|={max_diff:.6e} mean|diff|={diff.mean().item():.6e}"
def set_seed_all(seed: int):
"""Set random seed for all RNG sources to ensure reproducibility."""
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)
def instantiate_lerobot_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
) -> tuple[
GrootPolicy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
if from_pretrained:
policy = GrootPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
policy.config.embodiment_tag = "gr1"
else:
config = GrootConfig(
base_model_path=model_path,
n_action_steps=DUMMY_ACTION_HORIZON,
chunk_size=DUMMY_ACTION_HORIZON,
image_size=[IMAGE_SIZE, IMAGE_SIZE],
device=DEVICE,
embodiment_tag="gr1",
)
policy = GrootPolicy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_groot_pre_post_processors(
config=policy.config,
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
)
assert torch.allclose(lerobot_action, original_action, atol=ATOL, rtol=RTOL), (
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
return (policy, preprocessor, postprocessor)
def instantiate_original_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
):
"""Instantiate original Groot policy from NVIDIA's implementation."""
from gr00t.data.transform.concat import ConcatTransform
from gr00t.data.transform.state_action import StateActionToTensor
from gr00t.data.transform.video import VideoToNumpy, VideoToTensor
from gr00t.model.transforms import GR00TTransform
video_keys = ["video.ego_view"]
state_keys = [
"state"
] # Important: Use single concatenated "state" key (not split body parts) to match preprocessing
action_keys = [
"action.left_arm",
"action.right_arm",
"action.left_hand",
"action.right_hand",
"action.left_leg",
"action.right_leg",
"action.neck",
"action.waist",
]
language_keys = ["annotation.human.action.task_description"]
modality_config = {
"video": ModalityConfig(
delta_indices=[0], # Current frame only
modality_keys=video_keys,
),
"state": ModalityConfig(
delta_indices=[0],
modality_keys=state_keys,
),
"action": ModalityConfig(
delta_indices=list(range(DUMMY_ACTION_HORIZON)),
modality_keys=action_keys,
),
"language": ModalityConfig(
delta_indices=[0],
modality_keys=language_keys,
),
}
modality_transform = ComposedModalityTransform(
transforms=[
VideoToTensor(apply_to=video_keys),
VideoToNumpy(apply_to=video_keys), # Convert to numpy (GR00TTransform expects numpy arrays)
# State is already a single concatenated key, so no StateActionToTensor needed
# Convert action from numpy to tensor
StateActionToTensor(apply_to=action_keys),
# Concatenate only video and actions (state is already single key)
ConcatTransform(
video_concat_order=video_keys,
state_concat_order=[], # Empty:state is already single key
action_concat_order=action_keys,
),
GR00TTransform(
max_state_dim=64,
max_action_dim=32,
state_horizon=1,
action_horizon=DUMMY_ACTION_HORIZON,
training=False,
),
]
)
policy = Gr00tPolicy(
model_path=model_path,
embodiment_tag=EmbodimentTag.GR1,
modality_config=modality_config,
modality_transform=modality_transform,
device=DEVICE,
)
return policy, modality_config, modality_transform
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 2
prompt = "Pick up the red cube and place it in the bin"
state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)
batch = {
"observation.state": state,
"action": torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=device, # Action ground truth (for training)
),
"observation.images.ego_view": torch.rand(
batch_size,
3,
IMAGE_SIZE,
IMAGE_SIZE,
dtype=torch.float32,
device=device, # Images in [0, 1] range as expected by LeRobot
),
"task": [prompt for _ in range(batch_size)],
}
return batch
def convert_lerobot_to_original_format(batch, modality_config):
"""Convert LeRobot batch format to original Groot format.
The original Groot expects observations in this format:
{
"video.<camera_name>": np.ndarray (T, H, W, C) or (B, T, H, W, C)
"state.<state_component>": np.ndarray (T, D) or (B, T, D)
"action.<action_component>": np.ndarray (T, D) or (B, T, D)
"annotation.<annotation_type>": str or list[str]
}
"""
# Original Groot expects (T, H, W, C) format for images
# LeRobot has (B, C, H, W) format, so we need to convert
observation = {}
for img_key in ["ego_view"]:
lerobot_key = f"observation.images.{img_key}"
if lerobot_key in batch:
img = batch[lerobot_key]
# Convert from (B, C, H, W) to (B, T=1, H, W, C)
img_np = img.permute(0, 2, 3, 1).unsqueeze(1).cpu().numpy()
# Convert [0, 1] to [0, 255] uint8 as expected by original
img_np = (img_np * 255).astype(np.uint8)
observation[f"video.{img_key}"] = img_np
# Important: The Original's GR00TTransform expects "state" as (B, T, D), not split body parts
if "observation.state" in batch:
state = batch["observation.state"]
state_np = state.unsqueeze(1).cpu().numpy() # (B, 1, D)
observation["state"] = state_np
if "action" in batch:
action = batch["action"]
action_np = action.cpu().numpy()
start_idx = 0
for part_name, part_dim in GR1_BODY_PARTS.items():
end_idx = start_idx + part_dim
observation[f"action.{part_name}"] = action_np[:, :, start_idx:end_idx]
start_idx = end_idx
if "task" in batch:
task_list = batch["task"]
# GR00TTransform expects language with (B, T) shape for batched data
# Create a (B, T=1) array where each element is the string directly
bsz = len(task_list)
task_array = np.empty((bsz, 1), dtype=object)
for i in range(bsz):
task_array[i, 0] = task_list[i] # Assign string directly to each (i, 0) position
observation["annotation.human.action.task_description"] = task_array
return observation
def test_groot_original_vs_lerobot_pretrained():
"""Test Groot original implementation vs LeRobot implementation with pretrained weights."""
print("Test: Groot Original vs LeRobot with Pretrained Weights (Inference)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
print("\n[LeRobot] Running inference...")
lerobot_policy.eval()
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
lerobot_actions = lerobot_policy.select_action(batch_lerobot_processed)
print("\n[Original] Running inference...")
original_policy.model.eval()
observation = convert_lerobot_to_original_format(batch, modality_config)
original_obs_transformed = modality_transform(deepcopy(observation))
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
original_model_output = original_policy.model.get_action(original_obs_transformed)
original_actions_raw = original_model_output["action_pred"] # [2, 16, 32]
# Take first timestep
original_actions = original_actions_raw[:, 0, :].to(lerobot_actions.device).to(lerobot_actions.dtype)
print("Action Comparison:")
diff = lerobot_actions - original_actions
abs_diff = torch.abs(diff)
for batch_idx in range(lerobot_actions.shape[0]):
print(f"\n{'=' * 60}")
print(f"Batch {batch_idx}")
print(f"{'=' * 60}")
print(f"{'Idx':<5} {'LeRobot':<14} {'Original':<14} {'Difference':<14}")
print("-" * 60)
for action_idx in range(lerobot_actions.shape[1]):
lr_val = lerobot_actions[batch_idx, action_idx].item()
orig_val = original_actions[batch_idx, action_idx].item()
diff_val = abs(lr_val - orig_val)
sign = "+" if (lr_val - orig_val) > 0 else "-"
print(f"{action_idx:<5} {lr_val:>13.6f} {orig_val:>13.6f} {sign}{diff_val:>12.6f}")
max_diff = abs_diff.max().item()
tolerance = 0.001
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6f}"
)
print(f"\nSuccess: Actions match within tolerance ({tolerance})!")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation
cleanup_memory()
def test_groot_forward_pass_comparison():
"""Test forward pass comparison between LeRobot and Original Groot implementations."""
print("Test: Forward Pass Comparison (Training Mode)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
lerobot_policy.eval()
original_policy.model.eval()
print("\n[LeRobot] Running forward pass...")
batch_lerobot = deepcopy(batch)
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
set_seed_all(42)
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
print(f" Loss: {lerobot_loss.item():.6f}")
print("\n[Original] Running forward pass...")
observation = convert_lerobot_to_original_format(batch, modality_config)
transformed_obs = modality_transform(observation)
if "action" not in transformed_obs:
action_for_forward = batch_lerobot_processed["action"]
action_mask_for_forward = batch_lerobot_processed["action_mask"]
# Match action horizon if needed
if action_for_forward.shape[1] != original_policy.model.action_horizon:
if action_for_forward.shape[1] < original_policy.model.action_horizon:
pad_size = original_policy.model.action_horizon - action_for_forward.shape[1]
last_action = action_for_forward[:, -1:, :]
padding = last_action.repeat(1, pad_size, 1)
action_for_forward = torch.cat([action_for_forward, padding], dim=1)
mask_padding = torch.zeros(
action_mask_for_forward.shape[0],
pad_size,
action_mask_for_forward.shape[2],
dtype=action_mask_for_forward.dtype,
device=action_mask_for_forward.device,
)
action_mask_for_forward = torch.cat([action_mask_for_forward, mask_padding], dim=1)
else:
action_for_forward = action_for_forward[:, : original_policy.model.action_horizon, :]
action_mask_for_forward = action_mask_for_forward[
:, : original_policy.model.action_horizon, :
]
transformed_obs["action"] = action_for_forward
transformed_obs["action_mask"] = action_mask_for_forward
set_seed_all(42)
with torch.no_grad():
original_outputs = original_policy.model.forward(transformed_obs)
original_loss = original_outputs["loss"]
print(f" Loss: {original_loss.item():.6f}")
loss_diff = abs(lerobot_loss.item() - original_loss.item())
loss_rel_diff = loss_diff / (abs(original_loss.item()) + 1e-8) * 100
print("\nLoss Values:")
print(f" LeRobot: {lerobot_loss.item():.6f}")
print(f" Original: {original_loss.item():.6f}")
print(f" Absolute difference: {loss_diff:.6f}")
print(f" Relative difference: {loss_rel_diff:.2f}%")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation, transformed_obs
cleanup_memory()
-1
View File
@@ -1 +0,0 @@
"""Utilities shared by GR00T policy tests."""
@@ -1,198 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
"""Producer (run in the ORIGINAL gr00t env): dump original GR00T N1.7 outputs + inputs.
The original NVIDIA ``gr00t`` package pins ``transformers==4.57.3`` (py3.10) and its
model-config dataclasses are incompatible with the ``transformers==5.x`` that the
LeRobot GR00T N1.7 integration requires. The two implementations therefore cannot be
imported in the same Python process. To keep the parity comparison FAIR, we run the
original model in its native env here and serialize, PER EMBODIMENT TAG:
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
byte-identical tensors -- same image preprocessing, tokenization, normalization),
* the random seed used right before the flow-matching sampler,
* the raw ``action_pred`` tensor returned by ``model.get_action`` (normalized space,
before any per-implementation action decoding).
Inputs are built GENERICALLY from the checkpoint metadata (no per-tag hardcoding):
state keys + dims come from ``statistics.json``; video + language keys come from the
processor's per-embodiment modality configs. This lets us test many embodiment tags
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
``libero_sim``.
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
Usage:
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
--ckpt <path-to-GR00T-N1.7-LIBERO/libero_10> \
--out-dir tests/policies/groot/artifacts \
[--tags libero_sim,oxe_droid_relative_eef_relative_joint,...] \
[--device cuda] [--seed 42]
If --tags is omitted, every embodiment present in the checkpoint statistics is dumped.
"""
import argparse
import json
import os
from pathlib import Path
import numpy as np
import torch
IMAGE_SIZE = 256
BATCH_SIZE = 2
PROMPT = "pick up the black bowl and place it on the plate"
def load_statistics(ckpt: str) -> dict:
with open(os.path.join(ckpt, "statistics.json")) as f:
return json.load(f)
def make_observation(seed: int, video_keys, lang_key, state_spec):
"""Build a dummy observation dict generically from the embodiment metadata."""
rng = np.random.default_rng(seed)
video = {
k: rng.integers(0, 256, (BATCH_SIZE, 1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)
for k in video_keys
}
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
# present-but-empty so the processor's state transform finds every expected key.
state = {
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
for k, dim in state_spec
}
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
return {"video": video, "state": state, "language": language}
def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_path):
from gr00t.data.types import MessageType
video_keys = modality_cfg["video"].modality_keys
lang_key = modality_cfg["language"].modality_keys[0]
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
policy.modality_configs = {
k: v for k, v in policy.processor.get_modality_configs()[tag].items() if k != "rl_info"
}
policy.language_key = policy.modality_configs["language"].modality_keys[0]
torch.manual_seed(args.seed)
np.random.seed(args.seed)
unbatched = policy._unbatch_observation(observation)
processed = []
for obs in unbatched:
vla = policy._to_vla_step_data(obs)
processed.append(policy.processor([{"type": MessageType.EPISODE_STEP.value, "content": vla}]))
collated = policy.collate_fn(processed)
def to_dev(x):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.to(args.device, torch.float32)
if isinstance(x, torch.Tensor):
return x.to(args.device)
if isinstance(x, dict):
return {k: to_dev(v) for k, v in x.items()}
return x
collated = {k: to_dev(v) for k, v in collated.items()}
torch.manual_seed(args.seed)
with torch.inference_mode():
out = fair_model.get_action(**collated)
action_pred = out["action_pred"].float().cpu().numpy()
flat, meta = {}, {}
def flatten(prefix, obj):
if isinstance(obj, torch.Tensor):
arr = obj.float().cpu().numpy() if torch.is_floating_point(obj) else obj.cpu().numpy()
flat[f"in::{prefix}"] = arr
meta[f"in::{prefix}"] = str(obj.dtype)
elif isinstance(obj, dict):
for k, v in obj.items():
flatten(f"{prefix}.{k}" if prefix else k, v)
elif isinstance(obj, (list, tuple)):
flat[f"in::{prefix}"] = np.array(obj, dtype=object)
else:
flat[f"in::{prefix}"] = np.array(obj)
flatten("", collated)
out_path.parent.mkdir(parents=True, exist_ok=True)
np.savez(
out_path,
action_pred=action_pred,
seed=np.array(args.seed),
device=np.array(args.device),
embodiment_tag=np.array(tag),
meta_keys=np.array(list(meta.keys()), dtype=object),
meta_dtypes=np.array(list(meta.values()), dtype=object),
**flat,
)
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--ckpt", required=True)
ap.add_argument("--out-dir", required=True, help="directory for per-tag .npz files")
ap.add_argument("--tags", default="", help="comma-separated embodiment tags (default: all in stats)")
ap.add_argument("--device", default="cuda")
ap.add_argument("--seed", type=int, default=42)
args = ap.parse_args()
from gr00t.policy.gr00t_policy import Gr00tPolicy
from transformers import AutoConfig, AutoModel
stats = load_statistics(args.ckpt)
requested = [t.strip() for t in args.tags.split(",") if t.strip()] or list(stats.keys())
# Load the policy once (for its processor/preprocessing) on any valid tag.
bootstrap_tag = "libero_sim" if "libero_sim" in stats else requested[0]
policy = Gr00tPolicy(embodiment_tag=bootstrap_tag, model_path=args.ckpt, device=args.device)
all_modality = policy.processor.get_modality_configs()
# Load a FAIR model (SDPA + fp32) once and reuse across tags. Otherwise the
# original checkpoint default (flash_attention_2 + bf16) introduces kernel/rounding
# noise vs the LeRobot env (which has no flash_attn and runs SDPA).
cfg = AutoConfig.from_pretrained(args.ckpt, trust_remote_code=True)
cfg.use_flash_attention = False
cfg.load_bf16 = False
fair_model = AutoModel.from_pretrained(args.ckpt, config=cfg, trust_remote_code=True)
fair_model.to(device=args.device, dtype=torch.float32)
fair_model.eval()
out_dir = Path(args.out_dir)
done, skipped = [], []
for tag in requested:
if tag not in stats or tag not in all_modality:
print(f"[skip] {tag}: not present in checkpoint statistics/modality configs")
skipped.append(tag)
continue
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
try:
dump_one_tag(
policy, fair_model, tag, all_modality[tag], state_spec, args,
out_dir / f"original_n1_7_{tag}.npz",
)
done.append(tag)
except Exception as exc: # noqa: BLE001
print(f"[fail] {tag}: {type(exc).__name__}: {exc}")
skipped.append(tag)
print(f"\nDumped {len(done)} tags: {done}")
if skipped:
print(f"Skipped/failed {len(skipped)} tags: {skipped}")
if __name__ == "__main__":
main()
+35 -48
View File
@@ -1,5 +1,3 @@
#!/usr/bin/env python
# Copyright 2026 The Allen Institute for Artificial Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -35,16 +33,16 @@ pytest.importorskip("scipy")
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies import get_policy_class, make_policy_config
from lerobot.policies.molmoact2 import (
configuration_molmoact2 as molmoact2_config,
modeling_molmoact2 as molmoact2_modeling,
processor_molmoact2 as molmoact2_processor,
)
from lerobot.policies.molmoact2.configuration_molmoact2 import (
MolmoAct2Config,
MolmoAct2CosineDecayWithWarmupSchedulerConfig,
infer_molmoact2_max_sequence_length,
from lerobot.policies.molmoact2.configuration_molmoact2 import MolmoAct2Config
from lerobot.policies.molmoact2.modeling_molmoact2 import (
MolmoAct2Policy,
_apply_action_chunk_padding_mask,
_apply_action_dim_padding_mask,
_combine_rollout_seeds,
)
from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy
from lerobot.policies.molmoact2.processor_molmoact2 import (
MolmoAct2ClampNormalizedProcessorStep,
MolmoAct2MaskedNormalizerProcessorStep,
@@ -53,6 +51,7 @@ from lerobot.policies.molmoact2.processor_molmoact2 import (
_add_gripper_masks_to_stats,
_build_discrete_state_string,
_normalize_question_text,
infer_molmoact2_max_sequence_length,
make_molmoact2_pre_post_processors,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@@ -71,34 +70,38 @@ def test_molmoact2_policy_registration():
assert cfg.per_episode_seed is False
assert cfg.eval_seed is None
assert cfg.normalize_language is True
assert cfg.get_scheduler_preset().num_decay_steps is None
assert cfg.get_scheduler_preset().num_decay_steps == 100_000
assert cfg.action_delta_indices == list(range(cfg.chunk_size))
assert get_policy_class("molmoact2") is MolmoAct2Policy
def test_molmoact2_checkpoint_download_ignores_remote_python(monkeypatch):
import huggingface_hub
download_kwargs = {}
def fake_snapshot_download(**kwargs):
download_kwargs.update(kwargs)
return "/tmp/downloaded-molmoact2"
monkeypatch.setattr(molmoact2_config, "snapshot_download", fake_snapshot_download)
monkeypatch.setattr(huggingface_hub, "snapshot_download", fake_snapshot_download)
checkpoint_location = molmoact2_config._resolve_checkpoint_location("allenai/MolmoAct2")
checkpoint_location = molmoact2_modeling._resolve_checkpoint_location("allenai/MolmoAct2")
assert checkpoint_location == "/tmp/downloaded-molmoact2"
assert download_kwargs["ignore_patterns"] == ["*.py", "*.pyc", "__pycache__/*"]
def test_molmoact2_scheduler_decay_steps_auto_match_training_steps():
def test_molmoact2_scheduler_auto_scales_to_training_steps():
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
param = torch.nn.Parameter(torch.ones(()))
optimizer = torch.optim.AdamW([param], lr=0.001)
config = MolmoAct2CosineDecayWithWarmupSchedulerConfig(
config = CosineDecayWithWarmupSchedulerConfig(
peak_lr=0.01,
decay_lr=0.001,
num_warmup_steps=10,
num_decay_steps=None,
num_decay_steps=100_000,
)
scheduler = config.build(optimizer, num_training_steps=100)
@@ -123,9 +126,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
batch_size=3,
device=torch.device("cpu"),
)
expected_first = torch.Generator().manual_seed(
MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3)
)
expected_first = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3))
assert torch.allclose(torch.rand(4, generator=first), torch.rand(4, generator=expected_first))
policy.reset()
@@ -134,9 +135,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
batch_size=3,
device=torch.device("cpu"),
)
expected_second = torch.Generator().manual_seed(
MolmoAct2Policy._combine_rollout_seeds(first_seed=1003, batch_size=3)
)
expected_second = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1003, batch_size=3))
assert torch.allclose(torch.rand(4, generator=second), torch.rand(4, generator=expected_second))
policy.reset()
@@ -145,9 +144,7 @@ def test_molmoact2_rollout_generator_uses_eval_seed_per_task():
batch_size=3,
device=torch.device("cpu"),
)
expected_new_task = torch.Generator().manual_seed(
MolmoAct2Policy._combine_rollout_seeds(first_seed=1000, batch_size=3)
)
expected_new_task = torch.Generator().manual_seed(_combine_rollout_seeds(first_seed=1000, batch_size=3))
assert torch.allclose(torch.rand(4, generator=new_task), torch.rand(4, generator=expected_new_task))
@@ -537,36 +534,26 @@ def test_train_action_expert_only_requires_continuous_action_mode():
def test_molmoact2_sequence_length_is_inferred_from_fixed_token_budget():
cfg = MolmoAct2Config(
action_mode="both",
chunk_size=10,
n_action_steps=10,
image_keys=["observation.images.image", "observation.images.wrist_image"],
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,))},
)
assert cfg.max_sequence_length is None
assert cfg.inferred_max_sequence_length() == 640
assert cfg.inferred_max_sequence_length(include_discrete_action=False) == 576
assert (
infer_molmoact2_max_sequence_length(
num_images=2,
state_dim=8,
action_dim=7,
action_horizon=30,
include_discrete_action=True,
num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=True
)
== 640
)
assert (
infer_molmoact2_max_sequence_length(
num_images=2, state_dim=8, action_dim=7, action_horizon=10, include_discrete_action=False
)
== 576
)
assert (
infer_molmoact2_max_sequence_length(
num_images=2, state_dim=8, action_dim=7, action_horizon=30, include_discrete_action=True
)
== 768
)
def test_molmoact2_sequence_length_override_is_preserved():
cfg = MolmoAct2Config(max_sequence_length=1024)
assert cfg.inferred_max_sequence_length(num_images=2, state_dim=8, action_dim=7) == 1024
def test_train_action_expert_only_freezes_non_action_expert_params():
class DummyBackbone(torch.nn.Module):
def __init__(self):
@@ -963,7 +950,7 @@ def test_action_dim_padding_loss_reduces_like_old_trainer():
]
)
reduced = MolmoAct2Policy._apply_action_dim_padding_mask(loss, action_dim_is_pad)
reduced = _apply_action_dim_padding_mask(loss, action_dim_is_pad)
expected = torch.stack(
[
@@ -979,7 +966,7 @@ def test_action_chunk_padding_keeps_old_mean_denominator():
loss = torch.ones(1, 2, 4, 3)
action_horizon_is_pad = torch.tensor([[False, False, True, True]])
masked = MolmoAct2Policy._apply_action_chunk_padding_mask(loss, action_horizon_is_pad)
masked = _apply_action_chunk_padding_mask(loss, action_horizon_is_pad)
assert masked.mean().item() == 0.5