Compare commits

...

14 Commits

Author SHA1 Message Date
acwrenn53 0509ea05df Merge pull request #10 from acwrenn53/nvidia-gr00t-n17-lerobot-cleanup
Remove GR00T N1.5 support and fix LIBERO gripper action transform
2026-06-05 12:15:10 -07:00
Andrew Wrenn de1a9e5ad9 Reconnect GR00T relative action processors 2026-06-05 09:31:04 -07:00
groot-validation 6803439f22 groot: auto-enable LIBERO gripper action transform for libero_sim
GR00T N1.7 emits gripper in [0,1] but LIBERO expects [-1,1]. The decode
transform existed but was never auto-enabled for embodiment_tag=libero_sim,
so the policy scored 0% on LIBERO eval. Auto-set it in __post_init__ (still
overridable). LIBERO Spatial eval: 0% -> 98%.
2026-06-05 00:56:11 +00:00
nv-sachdevkartik 90d1e70da2 removed remaining N1.5 traces 2026-06-05 00:11:37 +00:00
nv-sachdevkartik a35ac22afd removed n1.5 dependency 2026-06-04 22:14:07 +00:00
Kartik fd7fed08e2 Merge branch 'huggingface:main' into nvidia-gr00t-n17-lerobot 2026-06-04 23:41:09 +02:00
acwrenn53 0c3cc4c9d6 Merge pull request #6 from acwrenn53/nvidia-gr00t-n17-lerobot-rtc-2
Nvidia gr00t n17 lerobot rtc 2
2026-06-03 16:10:49 -07:00
Andrew Wrenn 6caeac9d07 Ignore padded GR00T N1.7 RTC prefix rows 2026-06-03 14:04:31 -07:00
Andrew Wrenn 1d6810b814 Trim GR00T N1.7 RTC chunks to valid horizon 2026-06-03 13:51:35 -07:00
Andrew Wrenn de9af57475 Fix GR00T N1.7 RTC action decoding 2026-06-03 13:43:13 -07:00
Andrew Wrenn 364750ada2 Allow Groot fake RTC chunk prefetch 2026-06-02 14:20:00 -07:00
Andrew Wrenn 342d223706 Restore GR00T Flash Attention install guidance 2026-06-02 13:26:08 -07:00
Andrew Wrenn e3b203e5a7 Move Groot processor compatibility into Groot loader 2026-06-02 13:19:12 -07:00
Andrew Wrenn b568c41355 Add GR00T N1.7 support
Add GR00T N1.7 policy configuration, checkpoint compatibility, processor parity, LIBERO documentation, and focused tests.

Co-authored-by: Ryan Halabi <ryhalabi@nvidia.com>
2026-06-01 08:57:04 -07:00
17 changed files with 5069 additions and 1280 deletions
+1 -1
View File
@@ -105,7 +105,7 @@ lerobot-train \
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
+1 -1
View File
@@ -68,7 +68,7 @@
- local: eo1
title: EO-1
- local: groot
title: NVIDIA GR00T N1.5
title: NVIDIA GR00T
- local: xvla
title: X-VLA
- local: multi_task_dit
+1 -1
View File
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
- [SmolVLA](./smolvla)
- [Pi0.5](./pi05)
- [GR00T N1.5](./groot)
- [GR00T N1.7](./groot)
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
+76 -30
View File
@@ -1,16 +1,16 @@
# GR00T N1.5 Policy
# GR00T Policy
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
This document outlines the specifics of its integration and usage within the LeRobot framework.
LeRobot integrates GR00T N1.7 through the `groot` policy type.
## Model Overview
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
@@ -28,33 +28,46 @@ This approach allows the model to be highly adaptable through post-training for
## Installation Requirements
As of today, GR00T N1.5 requires flash attention for it's internal working.
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
```bash
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
--index-url https://download.pytorch.org/whl/cu128
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
```
3. Install and verify Flash Attention:
```bash
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
3. Install LeRobot by running:
4. Install LeRobot with the GR00T extra:
```bash
pip install lerobot[groot]
pip install "lerobot[groot]"
```
For a source checkout, use the same order, then install the local package with:
```bash
pip install -e ".[groot]"
```
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
## Usage
To use GR00T in your LeRobot configuration, specify the policy type as:
To use GR00T N1.7:
```python
policy.type=groot
```bash
--policy.type=groot \
--policy.model_version=n1.7
```
## Training
@@ -87,21 +100,54 @@ accelerate launch \
## Performance Results
### Libero Benchmark Results
### LIBERO Benchmark Results
> [!NOTE]
> Follow our instructions for Libero usage: [Libero](./libero)
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82.0% | 92.0% |
| **Libero Object** | 99.0% | 92.0% |
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
### GR00T N1.7 LIBERO Checkpoints
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
| Suite | Checkpoint subdirectory |
| -------------- | ----------------------- |
| LIBERO Spatial | `libero_spatial` |
| LIBERO Object | `libero_object` |
| LIBERO Goal | `libero_goal` |
| LIBERO 10 | `libero_10` |
Preliminary LeRobot integration results:
| Suite | Status | Success rate | n_episodes |
| -------------- | ------ | -----------: | ---------: |
| LIBERO Spatial | ✓ | ~95% | XX |
| LIBERO Object | ✓ | XX% | XX |
| LIBERO Goal | ✓ | XX% | XX |
| LIBERO 10 | ✓ | XX% | XX |
| **Average** | ✓ | **XX%** | **XX** |
Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
lerobot-eval \
--policy.type=groot \
--policy.model_version=n1.7 \
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
--policy.embodiment_tag=libero_sim \
--env.type=libero \
--env.task=libero_spatial \
--eval.n_episodes=50
```
Use `eval.n_episodes >= 50` per suite when reporting success rates.
### Evaluate in your hardware setup
@@ -131,4 +177,4 @@ lerobot-rollout\
## License
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
+4 -1
View File
@@ -24,4 +24,7 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
Hugging Face Models:
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
+14 -18
View File
@@ -280,26 +280,22 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
# GROOT handles normalization in groot_pack_inputs_v3 step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
return make_groot_pre_post_processors_from_pretrained(
config=policy_cfg,
pretrained_path=pretrained_path,
dataset_stats=kwargs.get("dataset_stats"),
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
preprocessor_config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
postprocessor_config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
)
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
+9 -1
View File
@@ -18,4 +18,12 @@ from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
def __getattr__(name: str):
if name in {"GR00TN17", "GR00TN17Config"}:
from .groot_n1_7 import GR00TN17, GR00TN17Config
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -181,8 +181,7 @@ class BasicTransformerBlock(nn.Module):
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
@@ -318,6 +317,71 @@ class DiT(ModelMixin, ConfigMixin):
return self.proj_out_2(hidden_states)
class AlternateVLDiT(DiT):
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
super().__init__(*args, **kwargs)
self.attend_text_every_n_blocks = attend_text_every_n_blocks
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
return_all_hidden_states: bool = False,
image_mask: torch.Tensor | None = None,
backbone_attention_mask: torch.Tensor | None = None,
):
if image_mask is None:
raise ValueError("image_mask is required for AlternateVLDiT.")
if backbone_attention_mask is None:
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
image_attention_mask = image_mask & backbone_attention_mask
non_image_attention_mask = (~image_mask) & backbone_attention_mask
all_hidden_states = [hidden_states]
if not self.config.interleave_self_attention:
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
curr_encoder_attention_mask = (
non_image_attention_mask
if idx % (2 * self.attend_text_every_n_blocks) == 0
else image_attention_mask
)
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=curr_encoder_attention_mask,
temb=temb,
)
all_hidden_states.append(hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
return self.proj_out_2(hidden_states)
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@@ -110,7 +110,7 @@ class MultiEmbodimentActionEncoder(nn.Module):
class FlowmatchingActionHeadConfig(PretrainedConfig):
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
"""Flow-matching action head used by GR00T backbones."""
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
@@ -14,12 +14,294 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5"
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7,
"n1_7": GROOT_N1_7,
"n1d7": GROOT_N1_7,
"n17": GROOT_N1_7,
"1.7": GROOT_N1_7,
}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
"none": None,
"": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
}
def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = GROOT_N1_7
raise ValueError(
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
)
return normalized
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
if transform is None:
return None
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
supported = ", ".join(
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
)
raise ValueError(
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
f"Supported transforms: none, {supported}."
)
return normalized
def infer_groot_model_version(model_path: str | None) -> str | None:
if not model_path:
return None
model_path_lower = model_path.lower()
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
return GROOT_N1_7
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
# N1.5 is unsupported, but it must still be recognised here to fail loudly
# rather than silently treating an N1.5 checkpoint as N1.7.
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
return GROOT_N1_5
config_version = _infer_groot_model_version_from_local_config(model_path)
if config_version is not None:
return config_version
return None
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
if model_path is None:
return False
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return False
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return False
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if "libero_sim" in modality_configs:
return "libero_sim"
if len(modality_configs) == 1:
return next(iter(modality_configs))
return None
def infer_groot_n1_7_action_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict):
return None
modality_configs = processor_kwargs.get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag is None:
return None
embodiment_config = modality_configs.get(embodiment_tag, {})
if not isinstance(embodiment_config, dict):
return None
action_config = embodiment_config.get("action", {})
if not isinstance(action_config, dict):
return None
delta_indices = action_config.get("delta_indices", [])
if not isinstance(delta_indices, list):
return None
return len(delta_indices) or None
def infer_groot_n1_7_action_execution_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
if action_horizon is None:
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag == "libero_sim":
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
# actions. Keeping that execution cadence avoids stale open-loop chunks.
return min(action_horizon, 8)
return action_horizon
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
model_path = Path(model_name).expanduser()
if model_path.exists():
return str(model_path)
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
return str(cached_snapshot) if cached_snapshot is not None else model_name
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
required_files = (
"config.json",
"tokenizer_config.json",
"preprocessor_config.json",
"video_preprocessor_config.json",
)
for hub_cache in _candidate_hf_hub_caches(cache_dir):
repo_cache = hub_cache / repo_cache_name
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.is_dir():
continue
candidates: list[Path] = []
ref_path = repo_cache / "refs" / "main"
try:
ref = ref_path.read_text().strip()
except OSError:
ref = ""
if ref:
candidates.append(snapshots_dir / ref)
candidates.extend(
sorted(
(path for path in snapshots_dir.iterdir() if path.is_dir()),
key=lambda path: path.stat().st_mtime,
reverse=True,
)
)
seen: set[Path] = set()
for snapshot in candidates:
if snapshot in seen:
continue
seen.add(snapshot)
if all((snapshot / filename).exists() for filename in required_files):
return snapshot
return None
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
candidates: list[Path] = []
if cache_dir is not None:
cache_path = Path(cache_dir).expanduser()
candidates.append(cache_path)
candidates.append(cache_path / "hub")
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
if hub_cache:
candidates.append(Path(hub_cache).expanduser())
hf_home = os.environ.get("HF_HOME")
if hf_home:
candidates.append(Path(hf_home).expanduser() / "hub")
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
deduped: list[Path] = []
seen: set[Path] = set()
for candidate in candidates:
resolved = candidate.resolve() if candidate.exists() else candidate
if resolved not in seen:
seen.add(resolved)
deduped.append(candidate)
return deduped
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return None
if not config_path.exists():
return None
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
return _infer_groot_model_version_from_config(config)
def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version")
if isinstance(model_version, str):
try:
return normalize_groot_model_version(model_version)
except ValueError:
return None
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
for candidate in candidates:
if not isinstance(candidate, str):
continue
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
return None
@PreTrainedConfig.register_subclass("groot")
@dataclass
@@ -52,11 +334,17 @@ class GrootConfig(PreTrainedConfig):
# Groot-specific model parameters (from groot_finetune_script.py)
# Path or HuggingFace model ID for the base Groot model
base_model_path: str = "nvidia/GR00T-N1.5-3B"
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# Path or HuggingFace model ID for the base Groot model
base_model_path: str | None = None
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
action_decode_transform: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -117,6 +405,38 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
self.base_model_path = GROOT_N1_7_BASE_MODEL
# The N1.7 LIBERO checkpoints emit a [0, 1] gripper action, but the LIBERO
# simulator expects the OpenVLA/[-1, 1] sign convention. NVIDIA's rollout
# wrapper applies this conversion; mirror it here so eval on the
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
raise ValueError(
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
super().__post_init__()
if self.n_action_steps > self.chunk_size:
@@ -192,7 +512,8 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
return list(range(min(self.chunk_size, 16)))
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
return list(range(min(self.chunk_size, model_action_horizon)))
@property
def reward_delta_indices(self) -> None:
-380
View File
@@ -1,380 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
from .action_head.flow_matching_action_head import (
FlowmatchingActionHead,
FlowmatchingActionHeadConfig,
)
from .utils import ensure_eagle_cache_ready
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
class EagleBackbone(nn.Module):
def __init__(
self,
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
project_to_dim: int = 1536,
):
"""
Args:
tune_llm: whether to tune the LLM model (default: True)
tune_visual: whether to tune the visual model (default: False)
"""
super().__init__()
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
try:
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
if project_to_dim is not None:
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
else:
self.eagle_linear = torch.nn.Identity()
# needed since we don't use these layers. Also saves compute
while len(self.eagle_model.language_model.model.layers) > select_layer:
self.eagle_model.language_model.model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual)
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for p in self.parameters():
p.requires_grad = True
if not tune_llm:
self.eagle_model.language_model.requires_grad_(False)
if not tune_visual:
self.eagle_model.vision_model.requires_grad_(False)
self.eagle_model.mlp1.requires_grad_(False)
print(f"Tune backbone llm: {self.tune_llm}")
print(f"Tune backbone visual: {self.tune_visual}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_llm and not tune_visual:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Backbone trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No backbone trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if self.eagle_model.language_model and not self.tune_llm:
self.eagle_model.language_model.eval()
if self.eagle_model.vision_model and not self.tune_visual:
self.eagle_model.vision_model.eval()
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
eagle_prefix = "eagle_"
eagle_input = {
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
}
del eagle_input["image_sizes"]
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
eagle_features = eagle_output.hidden_states[self.select_layer]
eagle_features = self.eagle_linear(eagle_features)
return eagle_features, eagle_input["attention_mask"]
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
if self.training and self.tune_visual:
dummy_term = torch.tensor(
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
)
for param in self.eagle_model.vision_model.parameters():
if param.requires_grad:
dummy_term = dummy_term + 0.0 * param.sum()
eagle_embeds = eagle_embeds + dummy_term
return BatchFeature(
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
) # [B, T2, hidden_size]
BACKBONE_FEATURE_KEY = "backbone_features"
ACTION_KEY = "action_pred"
LOSS_KEY = "loss"
ERROR_MSG = "Error: unexpected input/output"
N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
compute_dtype: str = "float32"
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
# real model
class GR00TN15(PreTrainedModel):
supports_gradient_checkpointing = True
config_class = GR00TN15Config
"""
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
here n is variable and can be e.g. time, 1 or user specified
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
"""
def __init__(
self,
config: GR00TN15Config,
local_model_path: str,
):
assert isinstance(config.backbone_cfg, dict)
assert isinstance(config.action_head_cfg, dict)
super().__init__(config)
self.local_model_path = local_model_path
self.backbone = EagleBackbone(**config.backbone_cfg)
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
self.action_head = FlowmatchingActionHead(action_head_cfg)
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
detected_error = False
error_msg = ERROR_MSG
if ACTION in inputs:
action = inputs[ACTION]
# In inference, action may be omitted or None; validate only when it's a tensor.
if action is None:
pass # allow None during inference
elif isinstance(action, torch.Tensor):
shape_ok = (
len(action.shape) == 3
and action.shape[1] == self.action_horizon
and action.shape[2] == self.action_dim
)
if not shape_ok:
error_msg += f"\n{action.shape=}"
detected_error = True
else:
# Unexpected non-tensor type provided for action
error_msg += f"\nInvalid type for action: {type(action)}"
detected_error = True
if "video" in inputs:
video = inputs["video"]
type_ok = isinstance(video, np.ndarray)
dtype_ok = video.dtype == np.uint8
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
if not type_ok:
error_msg += f"\n{type(video)=}"
detected_error = True
if not dtype_ok:
error_msg += f"\n{video.dtype=}"
detected_error = True
if not shape_ok:
error_msg += f"\n{video.shape=}"
detected_error = True
if detected_error:
raise ValueError(error_msg)
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
fail_backbone = (
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
)
if fail_backbone:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
raise ValueError(error_msg)
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
(
LOSS_KEY in action_head_outputs and is_training
) # there might not be an action prediction during training
or (
ACTION_KEY in action_head_outputs
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
)
)
if fail_action_head:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
error_msg += f"\n{self.action_horizon=}"
error_msg += f"\n{self.action_dim=}"
raise ValueError(error_msg)
def forward(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
return action_head_outputs
def get_action(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
return action_head_outputs
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
self.validate_inputs(inputs)
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_maybe_dtype(x):
# Cast floating tensors to a memory-efficient compute dtype when requested.
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
if getattr(self, "compute_dtype", None) == "bfloat16":
return x.to(self.device, dtype=torch.bfloat16)
# Fallback: preserve previous behavior if not using bf16 compute
return x.to(self.device, dtype=self.action_head.dtype)
# Non-floating tensors: move device only
return x.to(self.device)
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
return backbone_inputs, action_inputs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
print(f"Tune backbone vision tower: {tune_visual}")
print(f"Tune backbone LLM: {tune_llm}")
print(f"Tune action head projector: {tune_projector}")
print(f"Tune action head DiT: {tune_diffusion_model}")
# get the current model path being downloaded
try:
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
# saved in ~/.cache/huggingface/hub/
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
# HFValidationError, RepositoryNotFoundError
except (HFValidationError, RepositoryNotFoundError):
print(
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
)
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path, local_model_path=local_model_path, **kwargs
)
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
)
return pretrained_model
+962
View File
@@ -0,0 +1,962 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib
import json
import logging
from contextlib import suppress
from copy import deepcopy
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
try:
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
except ImportError:
Qwen3VLConfig = None
Qwen3VLForConditionalGeneration = None
logger = logging.getLogger(__name__)
def _copy_default(value: Any) -> Any:
return deepcopy(value)
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"model_dtype": "bfloat16",
"dtype": "bfloat16",
"model_name": "nvidia/Cosmos-Reason2-2B",
"backbone_model_type": "qwen",
"model_revision": None,
"tune_top_llm_layers": 0,
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 12,
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": (230, 230),
"image_target_size": (256, 256),
"shortest_image_edge": None,
"crop_fraction": None,
"random_rotation_angle": None,
"color_jitter_params": None,
"use_albumentations_transforms": True,
"extra_augmentation_config": None,
"formalize_language": True,
"apply_sincos_state_encoding": False,
"use_percentiles": True,
"use_relative_action": False,
"max_state_dim": 132,
"max_action_dim": 132,
"action_horizon": 40,
"hidden_size": 1024,
"input_embedding_dim": 1536,
"state_history_length": 1,
"add_pos_embed": True,
"attn_dropout": 0.2,
"use_vlln": True,
"max_seq_len": 1024,
"use_alternate_vl_dit": True,
"attend_text_every_n_blocks": 2,
"diffusion_model_cfg": {
"positional_embeddings": None,
"num_layers": 32,
"num_attention_heads": 32,
"attention_head_dim": 48,
"norm_type": "ada_norm",
"dropout": 0.2,
"final_dropout": True,
"output_dim": 1024,
"interleave_self_attention": True,
},
"vl_self_attention_cfg": {
"positional_embeddings": None,
"num_layers": 4,
"num_attention_heads": 32,
"attention_head_dim": 64,
"dropout": 0.2,
"final_dropout": True,
},
"num_inference_timesteps": 4,
"noise_beta_alpha": 1.5,
"noise_beta_beta": 1.0,
"noise_s": 0.999,
"num_timestep_buckets": 1000,
"tune_projector": True,
"tune_diffusion_model": True,
"tune_vlln": True,
"state_dropout_prob": 0.2,
"exclude_state": False,
"use_mean_std": False,
"max_num_embodiments": 32,
"rtc_ramp_rate": 6.0,
}
class GR00TN17Config(PretrainedConfig):
"""Configuration for NVIDIA GR00T N1.7.
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
flow-matching action head. This mirrors the public N1.7 checkpoint config
while keeping it local to LeRobot and independent from the external
Isaac-GR00T ``gr00t`` Python package.
"""
model_type = "Gr00tN1d7"
_defaults = GR00T_N1_7_DEFAULTS
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in GR00T_N1_7_DEFAULTS.items():
setattr(self, key, _copy_default(kwargs.pop(key, value)))
for key, value in kwargs.items():
setattr(self, key, value)
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
cfg = self.to_dict()
if not exclude_augment:
return cfg
exclude_keys = {
"random_rotation_angle",
"color_jitter_params",
"use_albumentations_transforms",
"formalize_language",
"image_crop_size",
"image_target_size",
"shortest_image_edge",
"crop_fraction",
}
return {k: v for k, v in cfg.items() if k not in exclude_keys}
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
class CategorySpecificLinear(nn.Module):
"""Linear layer with category-specific weights for multi-embodiment support."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
super().__init__()
self.num_categories = num_categories
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
"""Two-layer MLP with category-specific weights."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class SinusoidalPositionalEncoding(nn.Module):
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
The frequency scalar is intentionally created on CPU and then broadcast with
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
embedding and avoids tiny dtype/device construction differences in parity
tests.
"""
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class MultiEmbodimentActionEncoder(nn.Module):
"""Action encoder with category-specific projections and sinusoidal time encoding."""
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
super().__init__()
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
batch_size, horizon, _ = actions.shape
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
raise ValueError("Expected `timesteps` to have shape (B,).")
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
action_emb = self.W1(actions, cat_ids)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
return self.W3(x, cat_ids)
class Qwen3Backbone(nn.Module):
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
The public checkpoint stores the action head in the GR00T checkpoint but
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
keeps the nested HF module layout compatible across transformer versions
and exposes the hidden states consumed by the action head.
"""
def __init__(
self,
model_name: str = "nvidia/Cosmos-Reason2-2B",
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
tune_top_llm_layers: int = 0,
trainable_params_fp32: bool = False,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_pretrained_weights: bool = True,
):
if Qwen3VLForConditionalGeneration is None:
raise ImportError(
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
"or use a transformers version that provides Qwen3-VL support."
)
super().__init__()
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
extra_kwargs: dict[str, Any] = {}
if use_flash_attention:
try:
import flash_attn # noqa: F401
extra_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
extra_kwargs["attn_implementation"] = "sdpa"
if load_bf16:
extra_kwargs["torch_dtype"] = torch.bfloat16
if load_pretrained_weights:
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
model_name,
**extra_kwargs,
**transformers_loading_kwargs,
).eval()
else:
self.model = self._from_backbone_config(
model_name=model_name,
model_kwargs=extra_kwargs,
config_kwargs=transformers_loading_kwargs,
).eval()
while len(self.language_model.layers) > select_layer:
self.language_model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
if load_bf16 and trainable_params_fp32:
for parameter in self.parameters():
if parameter.requires_grad:
parameter.data = parameter.data.to(torch.float32)
def set_trainable_parameters(
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
) -> None:
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_llm:
self.language_model.requires_grad_(False)
if not tune_visual:
self.visual.requires_grad_(False)
if tune_top_llm_layers > 0:
for layer in self.language_model.layers[-tune_top_llm_layers:]:
for parameter in layer.parameters():
parameter.requires_grad = True
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if self.language_model and not self.tune_llm:
self.language_model.eval()
if self.visual and not self.tune_visual:
self.visual.eval()
@property
def language_model(self) -> nn.Module:
return getattr(self.model, "model", self.model).language_model
@property
def visual(self) -> nn.Module:
return getattr(self.model, "model", self.model).visual
def _from_backbone_config(
self,
*,
model_name: str,
model_kwargs: dict[str, Any],
config_kwargs: dict[str, Any],
) -> nn.Module:
if _is_cosmos_reason2_backbone(model_name):
backbone_config = _cosmos_reason2_qwen3_vl_config()
else:
if AutoConfig is None:
raise ImportError(
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
if "mm_token_type_ids" in model_input:
return
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
return
input_ids = model_input.get("input_ids")
if input_ids is None:
return
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None:
mm_token_type_ids[input_ids == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[input_ids == video_token_id] = 2
model_input["mm_token_type_ids"] = mm_token_type_ids
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
drops text position ids before calling text-layer flash attention. GR00T
N1.7 was aligned against the older Transformers path, where a fourth text
position row is forwarded alongside the temporal/height/width rows. Adding
the row here preserves the newer multimodal position computation while
keeping flash attention on the legacy code path.
"""
if "position_ids" in model_input:
return
qwen3_model = getattr(self.model, "model", self.model)
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
if compute_3d_position_ids is None:
return
position_ids = compute_3d_position_ids(
input_ids=model_input.get("input_ids"),
image_grid_thw=model_input.get("image_grid_thw"),
video_grid_thw=model_input.get("video_grid_thw"),
inputs_embeds=None,
attention_mask=model_input.get("attention_mask"),
past_key_values=None,
mm_token_type_ids=model_input.get("mm_token_type_ids"),
)
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
model_input["position_ids"] = position_ids
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
Newer releases expose the post-final-norm tensor there instead. Capturing
the last decoder layer output directly keeps the N1.7 action head input
stable across Transformers versions.
"""
captured: dict[str, torch.Tensor] = {}
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
if isinstance(output, torch.Tensor):
captured["features"] = output
elif isinstance(output, (tuple, list)) and output:
captured["features"] = output[0]
elif hasattr(output, "last_hidden_state"):
captured["features"] = output.last_hidden_state
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
try:
outputs = self.model(**model_input, output_hidden_states=True)
finally:
hook.remove()
return captured.get("features", outputs.hidden_states[-1])
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
model_input = {key: vl_input[key] for key in keys_to_use}
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
self._ensure_mm_token_type_ids(model_input)
self._ensure_legacy_qwen3_position_ids(model_input)
features = self._last_decoder_layer_output(model_input)
image_mask = model_input["input_ids"] == self.model.config.image_token_id
attention_mask = model_input["attention_mask"] == 1
return BatchFeature(
data={
"backbone_features": features,
"backbone_attention_mask": attention_mask,
"image_mask": image_mask,
}
)
class GR00TN17ActionHead(nn.Module):
supports_gradient_checkpointing = True
def __init__(self, config: GR00TN17Config):
require_package("diffusers", extra="groot")
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
if config.use_alternate_vl_dit:
self.model = AlternateVLDiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
)
else:
self.model = DiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
)
self.action_dim = config.max_action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim * config.state_history_length,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
else:
self.vl_self_attention = nn.Identity()
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.state_dropout_prob = config.state_dropout_prob
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
def set_trainable_parameters(
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
) -> None:
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
self.tune_vlln = tune_vlln
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
if not tune_vlln:
self.vlln.requires_grad_(False)
self.vl_self_attention.requires_grad_(False)
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
if not self.tune_vlln:
self.vlln.eval()
self.vl_self_attention.eval()
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
if self._beta_dist is None:
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (1 - sample) * self.config.noise_s
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = self.vlln(backbone_output["backbone_features"])
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
vl_embeds = backbone_output.backbone_features
device = vl_embeds.device
embodiment_id = action_input.embodiment_id
if action_input.state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = action_input.state.view(action_input.state.shape[0], 1, -1)
state_features = self.state_encoder(state, embodiment_id)
if self.training and self.state_dropout_prob > 0:
do_dropout = (
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
)
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None]
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
return BatchFeature(
data={
"loss": loss,
"action_loss": action_loss,
"action_mask": action_mask,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
state = action_input.state
if state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = state.view(state.shape[0], 1, -1)
state_features = self.state_encoder(state, action_input.embodiment_id)
return BatchFeature(
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
)
@torch.no_grad()
def get_action_with_features(
self,
backbone_features: torch.Tensor,
state_features: torch.Tensor,
embodiment_id: torch.Tensor,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
vl_embeds = backbone_features
batch_size = vl_embeds.shape[0]
device = vl_embeds.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.action_dim),
dtype=vl_embeds.dtype,
device=device,
)
dt = 1.0 / self.num_inference_timesteps
vel_strength = torch.ones_like(actions)
if "action" in action_input:
if options is None:
raise ValueError("RTC options are required when action is provided to get_action.")
action_horizon_before_padding = options["action_horizon"]
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
:,
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
:,
]
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
ramp = ramp / ramp[-1].clamp_min(1e-8)
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
None, :, None
].to(device)
for t_step in range(self.num_inference_timesteps):
t_cont = t_step / float(self.num_inference_timesteps)
t_discretized = int(t_cont * self.num_timestep_buckets)
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
return BatchFeature(
data={
"action_pred": actions,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
@torch.no_grad()
def get_action(
self,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
features = self._encode_features(backbone_output, action_input)
return self.get_action_with_features(
backbone_features=features.backbone_features,
state_features=features.state_features,
embodiment_id=action_input.embodiment_id,
backbone_output=backbone_output,
action_input=action_input,
options=options,
)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
if Qwen3VLConfig is None:
raise ImportError(
"Qwen3VLConfig is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
return Qwen3VLConfig(
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=True,
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"tie_word_embeddings": True,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [5, 11, 17],
"depth": 24,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1024,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 2048,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
)
def get_backbone_cls(config: GR00TN17Config):
if (
config.backbone_model_type == "qwen"
or "nvidia/Cosmos-Reason2" in config.model_name
or "Qwen/Qwen3-VL" in config.model_name
):
return Qwen3Backbone
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
class GR00TN17(PreTrainedModel):
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
config_class = GR00TN17Config
supports_gradient_checkpointing = True
def __init__(
self,
config: GR00TN17Config,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_backbone_weights: bool = True,
):
super().__init__(config)
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
self.config = config
backbone_cls = get_backbone_cls(config)
self.backbone = backbone_cls(
model_name=config.model_name,
tune_llm=config.tune_llm,
tune_visual=config.tune_visual,
select_layer=config.select_layer,
reproject_vision=config.reproject_vision,
use_flash_attention=config.use_flash_attention,
load_bf16=config.load_bf16,
tune_top_llm_layers=config.tune_top_llm_layers,
trainable_params_fp32=config.backbone_trainable_params_fp32,
transformers_loading_kwargs=transformers_loading_kwargs,
load_pretrained_weights=load_backbone_weights,
)
self.action_head = GR00TN17ActionHead(config)
self.post_init()
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
global tree
if tree is None:
require_package("dm-tree", extra="groot", import_name="tree")
tree = importlib.import_module("tree")
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_dtype(x):
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
return x.to(self.device, dtype=self.dtype)
return x.to(self.device)
return (
tree.map_structure(to_device_with_dtype, backbone_inputs),
tree.map_structure(to_device_with_dtype, action_inputs),
)
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head(backbone_outputs, action_inputs)
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head.get_action(backbone_outputs, action_inputs, options)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
tune_vlln = kwargs.pop("tune_vlln", True)
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
for key in ("revision", "cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
try:
local_model_path = snapshot_download(
pretrained_model_name_or_path,
repo_type="model",
revision=kwargs.get("revision"),
cache_dir=kwargs.get("cache_dir"),
local_files_only=kwargs.get("local_files_only", False),
token=kwargs.get("token"),
)
except (HFValidationError, RepositoryNotFoundError):
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path,
transformers_loading_kwargs=transformers_loading_kwargs,
load_backbone_weights=load_backbone_weights,
**kwargs,
)
pretrained_model.backbone.set_trainable_parameters(
tune_visual=tune_visual,
tune_llm=tune_llm,
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector,
tune_diffusion_model=tune_diffusion_model,
tune_vlln=tune_vlln,
)
return pretrained_model
def _register_with_transformers() -> None:
if AutoConfig is None or AutoModel is None:
return
try:
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
try:
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoModel.register(GR00TN17Config, GR00TN17)
_register_with_transformers()
+218 -47
View File
@@ -17,14 +17,8 @@
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
@@ -46,8 +40,14 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from .configuration_groot import GrootConfig
from .groot_n1 import GR00TN15
from .configuration_groot import (
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
normalize_groot_model_version,
)
T = TypeVar("T", bound="GrootPolicy")
@@ -67,27 +67,28 @@ class GrootPolicy(PreTrainedPolicy):
# Initialize GR00T model using ported components
self._groot_model = self._create_groot_model()
self._action_queue_steps = self._resolve_action_queue_steps()
self.reset()
def _create_groot_model(self):
"""Create and initialize the GR00T model using Isaac-GR00T API.
This is only called when creating a NEW policy (not when loading from checkpoint).
Steps (delegating to Isaac-GR00T):
1) Download and load pretrained model via GR00TN15.from_pretrained
2) Align action horizon with data_config if provided
"""
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
model = GR00TN15.from_pretrained(
pretrained_model_name_or_path=self.config.base_model_path,
tune_llm=self.config.tune_llm,
tune_visual=self.config.tune_visual,
tune_projector=self.config.tune_projector,
tune_diffusion_model=self.config.tune_diffusion_model,
model_kwargs = {
"pretrained_model_name_or_path": self.config.base_model_path,
"tune_llm": self.config.tune_llm,
"tune_visual": self.config.tune_visual,
"tune_projector": self.config.tune_projector,
"tune_diffusion_model": self.config.tune_diffusion_model,
}
from .groot_n1_7 import GR00TN17
model = GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=True,
transformers_loading_kwargs={"trust_remote_code": True},
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
@@ -97,7 +98,7 @@ class GrootPolicy(PreTrainedPolicy):
def reset(self):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
self._action_queue = deque([], maxlen=self._action_queue_steps)
@classmethod
def from_pretrained(
@@ -118,7 +119,7 @@ class GrootPolicy(PreTrainedPolicy):
"""Load Groot policy from pretrained model.
Handles two cases:
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
1. Base GR00T N1.7 models - loads the raw model
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
Args:
@@ -141,8 +142,13 @@ class GrootPolicy(PreTrainedPolicy):
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
requested_version = (
normalize_groot_model_version(config.model_version)
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
print(
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
)
@@ -193,8 +199,12 @@ class GrootPolicy(PreTrainedPolicy):
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
# Create default config with the pretrained path
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
config = GrootConfig(
model_version=model_version,
base_model_path=str(pretrained_name_or_path),
)
# Add minimal visual feature required for validation
# validate_features() will automatically add state and action features
@@ -215,6 +225,13 @@ class GrootPolicy(PreTrainedPolicy):
if hasattr(config, key):
setattr(config, key, value)
config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version:
raise ValueError(
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -225,18 +242,161 @@ class GrootPolicy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _resolve_action_queue_steps(self) -> int:
n_action_steps = int(self.config.n_action_steps)
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
execution_horizon = infer_groot_n1_7_action_execution_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
horizons = [n_action_steps]
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
return min(horizons)
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
horizons = [actions.shape[1]]
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
for horizon in (self.config.chunk_size, self.config.n_action_steps):
horizon = int(horizon)
if horizon > 0:
horizons.append(horizon)
return max(1, min(horizons))
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
allowed_base = {"state", "state_mask", "embodiment_id"}
if include_action:
allowed_base.update({"action", "action_mask"})
allowed_base.update(
{
"input_ids",
"attention_mask",
"pixel_values",
"image_grid_thw",
"mm_token_type_ids",
"pixel_values_videos",
"video_grid_thw",
}
)
allowed_base.add("action_mask")
return {
k: v
for k, v in batch.items()
if k in allowed_base and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
self,
inputs: dict[str, Tensor],
*,
inference_delay: object,
prev_chunk_left_over: object,
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
if prev_chunk_left_over is None:
return inputs, None
if not isinstance(prev_chunk_left_over, torch.Tensor):
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
if prev_chunk_left_over.numel() == 0:
return inputs, None
prev_actions = prev_chunk_left_over
if prev_actions.ndim == 2:
prev_actions = prev_actions.unsqueeze(0)
elif prev_actions.ndim != 3:
raise ValueError(
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
)
state = inputs.get("state")
if state is None:
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
batch_size = state.shape[0]
if prev_actions.shape[0] == 1 and batch_size > 1:
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
elif prev_actions.shape[0] != batch_size:
raise ValueError(
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
)
# The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
# provided prefix row as a real action constraint, so strip that padding
# before constructing the native overlap options.
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
if valid_prefix_rows.any():
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
prev_actions = prev_actions[:, :valid_prefix_steps, :]
else:
return inputs, None
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon:
prev_actions = prev_actions[:, -model_action_horizon:, :]
action_horizon = int(prev_actions.shape[1])
if action_horizon <= 0:
return inputs, None
if prev_actions.shape[2] > max_action_dim:
prev_actions = prev_actions[:, :, :max_action_dim]
elif prev_actions.shape[2] < max_action_dim:
pad = torch.zeros(
prev_actions.shape[0],
prev_actions.shape[1],
max_action_dim - prev_actions.shape[2],
dtype=prev_actions.dtype,
device=prev_actions.device,
)
prev_actions = torch.cat([prev_actions, pad], dim=2)
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
rtc_config = getattr(self.config, "rtc_config", None)
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
overlap_steps = max(0, min(action_horizon, execution_horizon))
if overlap_steps == 0:
return inputs, None
try:
frozen_steps = int(inference_delay or 0)
except (TypeError, ValueError):
frozen_steps = 0
frozen_steps = max(0, min(frozen_steps, overlap_steps))
options = {
"action_horizon": action_horizon,
"rtc_overlap_steps": overlap_steps,
"rtc_frozen_steps": frozen_steps,
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
}
inputs = dict(inputs)
inputs["action"] = prev_actions
return inputs, options
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Training forward pass.
Delegates to Isaac-GR00T model.forward when inputs are compatible.
"""
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
# Get device from model parameters
device = next(self.parameters()).device
@@ -254,32 +414,43 @@ class GrootPolicy(PreTrainedPolicy):
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
Returns a tensor of shape (B, n_action_steps, action_dim).
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
action-overlap options before calling the underlying model.
"""
self.eval()
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
# Preprocessing is handled by the processor pipeline, so we just filter the batch
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
allowed_base = {"state", "state_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
groot_options = None
if self.config.model_version == GROOT_N1_7:
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
groot_inputs,
inference_delay=kwargs.get("inference_delay"),
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
)
# Get device from model parameters
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.get_action(groot_inputs)
if groot_options is not None:
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
else:
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
prediction_horizon = self._resolve_prediction_horizon(actions)
actions = actions[:, :prediction_horizon]
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
@@ -292,7 +463,7 @@ class GrootPolicy(PreTrainedPolicy):
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)
self._action_queue.extend(actions.transpose(0, 1))
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
return self._action_queue.popleft()
# -------------------------
File diff suppressed because it is too large Load Diff
-47
View File
@@ -1,47 +0,0 @@
from pathlib import Path
from shutil import copytree
from huggingface_hub import hf_hub_download
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
"""
cache_dir = Path(cache_dir)
vendor_dir = Path(vendor_dir)
try:
# Populate/refresh cache with vendor files to ensure a complete processor directory
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
required_assets = [
"vocab.json",
"merges.txt",
"added_tokens.json",
"chat_template.json",
"special_tokens_map.json",
"config.json",
"generation_config.json",
"preprocessor_config.json",
"processor_config.json",
"tokenizer_config.json",
]
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
for fname in required_assets:
dst = cache_dir / fname
if not dst.exists():
print(f"[GROOT] Fetching {fname}")
hf_hub_download(
repo_id=assets_repo,
filename=fname,
repo_type="model",
local_dir=str(cache_dir),
)
File diff suppressed because it is too large Load Diff
@@ -1,444 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify Groot policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import gc
import os
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
pytest.importorskip("gr00t")
pytest.importorskip("transformers")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local Groot installation and is not meant for CI",
)
from gr00t.data.dataset import ModalityConfig # noqa: E402
from gr00t.data.embodiment_tags import EmbodimentTag # noqa: E402
from gr00t.data.transform.base import ComposedModalityTransform # noqa: E402
from gr00t.model.policy import Gr00tPolicy # noqa: E402
# GR1 humanoid dimensions (from pretrained model metadata)
# The actual GR1 robot has 44 dimensions for both state and action
# GR00TTransform will pad state to 64 and truncate action to 32
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = "cpu"
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
GR1_BODY_PARTS = {
"left_arm": 7,
"left_hand": 6,
"left_leg": 6,
"neck": 3,
"right_arm": 7,
"right_hand": 6,
"right_leg": 6,
"waist": 3,
}
def cleanup_memory():
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
print("\nCleaning up memory...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
print("Memory cleanup complete.")
def set_seed_all(seed: int):
"""Set random seed for all RNG sources to ensure reproducibility."""
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)
def instantiate_lerobot_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
) -> tuple[
GrootPolicy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
if from_pretrained:
policy = GrootPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
policy.config.embodiment_tag = "gr1"
else:
config = GrootConfig(
base_model_path=model_path,
n_action_steps=DUMMY_ACTION_HORIZON,
chunk_size=DUMMY_ACTION_HORIZON,
image_size=[IMAGE_SIZE, IMAGE_SIZE],
device=DEVICE,
embodiment_tag="gr1",
)
policy = GrootPolicy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_groot_pre_post_processors(
config=policy.config,
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
)
return (policy, preprocessor, postprocessor)
def instantiate_original_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
):
"""Instantiate original Groot policy from NVIDIA's implementation."""
from gr00t.data.transform.concat import ConcatTransform
from gr00t.data.transform.state_action import StateActionToTensor
from gr00t.data.transform.video import VideoToNumpy, VideoToTensor
from gr00t.model.transforms import GR00TTransform
video_keys = ["video.ego_view"]
state_keys = [
"state"
] # Important: Use single concatenated "state" key (not split body parts) to match preprocessing
action_keys = [
"action.left_arm",
"action.right_arm",
"action.left_hand",
"action.right_hand",
"action.left_leg",
"action.right_leg",
"action.neck",
"action.waist",
]
language_keys = ["annotation.human.action.task_description"]
modality_config = {
"video": ModalityConfig(
delta_indices=[0], # Current frame only
modality_keys=video_keys,
),
"state": ModalityConfig(
delta_indices=[0],
modality_keys=state_keys,
),
"action": ModalityConfig(
delta_indices=list(range(DUMMY_ACTION_HORIZON)),
modality_keys=action_keys,
),
"language": ModalityConfig(
delta_indices=[0],
modality_keys=language_keys,
),
}
modality_transform = ComposedModalityTransform(
transforms=[
VideoToTensor(apply_to=video_keys),
VideoToNumpy(apply_to=video_keys), # Convert to numpy (GR00TTransform expects numpy arrays)
# State is already a single concatenated key, so no StateActionToTensor needed
# Convert action from numpy to tensor
StateActionToTensor(apply_to=action_keys),
# Concatenate only video and actions (state is already single key)
ConcatTransform(
video_concat_order=video_keys,
state_concat_order=[], # Empty:state is already single key
action_concat_order=action_keys,
),
GR00TTransform(
max_state_dim=64,
max_action_dim=32,
state_horizon=1,
action_horizon=DUMMY_ACTION_HORIZON,
training=False,
),
]
)
policy = Gr00tPolicy(
model_path=model_path,
embodiment_tag=EmbodimentTag.GR1,
modality_config=modality_config,
modality_transform=modality_transform,
device=DEVICE,
)
return policy, modality_config, modality_transform
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 2
prompt = "Pick up the red cube and place it in the bin"
state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)
batch = {
"observation.state": state,
"action": torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=device, # Action ground truth (for training)
),
"observation.images.ego_view": torch.rand(
batch_size,
3,
IMAGE_SIZE,
IMAGE_SIZE,
dtype=torch.float32,
device=device, # Images in [0, 1] range as expected by LeRobot
),
"task": [prompt for _ in range(batch_size)],
}
return batch
def convert_lerobot_to_original_format(batch, modality_config):
"""Convert LeRobot batch format to original Groot format.
The original Groot expects observations in this format:
{
"video.<camera_name>": np.ndarray (T, H, W, C) or (B, T, H, W, C)
"state.<state_component>": np.ndarray (T, D) or (B, T, D)
"action.<action_component>": np.ndarray (T, D) or (B, T, D)
"annotation.<annotation_type>": str or list[str]
}
"""
# Original Groot expects (T, H, W, C) format for images
# LeRobot has (B, C, H, W) format, so we need to convert
observation = {}
for img_key in ["ego_view"]:
lerobot_key = f"observation.images.{img_key}"
if lerobot_key in batch:
img = batch[lerobot_key]
# Convert from (B, C, H, W) to (B, T=1, H, W, C)
img_np = img.permute(0, 2, 3, 1).unsqueeze(1).cpu().numpy()
# Convert [0, 1] to [0, 255] uint8 as expected by original
img_np = (img_np * 255).astype(np.uint8)
observation[f"video.{img_key}"] = img_np
# Important: The Original's GR00TTransform expects "state" as (B, T, D), not split body parts
if "observation.state" in batch:
state = batch["observation.state"]
state_np = state.unsqueeze(1).cpu().numpy() # (B, 1, D)
observation["state"] = state_np
if "action" in batch:
action = batch["action"]
action_np = action.cpu().numpy()
start_idx = 0
for part_name, part_dim in GR1_BODY_PARTS.items():
end_idx = start_idx + part_dim
observation[f"action.{part_name}"] = action_np[:, :, start_idx:end_idx]
start_idx = end_idx
if "task" in batch:
task_list = batch["task"]
# GR00TTransform expects language with (B, T) shape for batched data
# Create a (B, T=1) array where each element is the string directly
bsz = len(task_list)
task_array = np.empty((bsz, 1), dtype=object)
for i in range(bsz):
task_array[i, 0] = task_list[i] # Assign string directly to each (i, 0) position
observation["annotation.human.action.task_description"] = task_array
return observation
def test_groot_original_vs_lerobot_pretrained():
"""Test Groot original implementation vs LeRobot implementation with pretrained weights."""
print("Test: Groot Original vs LeRobot with Pretrained Weights (Inference)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
print("\n[LeRobot] Running inference...")
lerobot_policy.eval()
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
lerobot_actions = lerobot_policy.select_action(batch_lerobot_processed)
print("\n[Original] Running inference...")
original_policy.model.eval()
observation = convert_lerobot_to_original_format(batch, modality_config)
original_obs_transformed = modality_transform(deepcopy(observation))
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
original_model_output = original_policy.model.get_action(original_obs_transformed)
original_actions_raw = original_model_output["action_pred"] # [2, 16, 32]
# Take first timestep
original_actions = original_actions_raw[:, 0, :].to(lerobot_actions.device).to(lerobot_actions.dtype)
print("Action Comparison:")
diff = lerobot_actions - original_actions
abs_diff = torch.abs(diff)
for batch_idx in range(lerobot_actions.shape[0]):
print(f"\n{'=' * 60}")
print(f"Batch {batch_idx}")
print(f"{'=' * 60}")
print(f"{'Idx':<5} {'LeRobot':<14} {'Original':<14} {'Difference':<14}")
print("-" * 60)
for action_idx in range(lerobot_actions.shape[1]):
lr_val = lerobot_actions[batch_idx, action_idx].item()
orig_val = original_actions[batch_idx, action_idx].item()
diff_val = abs(lr_val - orig_val)
sign = "+" if (lr_val - orig_val) > 0 else "-"
print(f"{action_idx:<5} {lr_val:>13.6f} {orig_val:>13.6f} {sign}{diff_val:>12.6f}")
max_diff = abs_diff.max().item()
tolerance = 0.001
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6f}"
)
print(f"\nSuccess: Actions match within tolerance ({tolerance})!")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation
cleanup_memory()
def test_groot_forward_pass_comparison():
"""Test forward pass comparison between LeRobot and Original Groot implementations."""
print("Test: Forward Pass Comparison (Training Mode)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
lerobot_policy.eval()
original_policy.model.eval()
print("\n[LeRobot] Running forward pass...")
batch_lerobot = deepcopy(batch)
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
set_seed_all(42)
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
print(f" Loss: {lerobot_loss.item():.6f}")
print("\n[Original] Running forward pass...")
observation = convert_lerobot_to_original_format(batch, modality_config)
transformed_obs = modality_transform(observation)
if "action" not in transformed_obs:
action_for_forward = batch_lerobot_processed["action"]
action_mask_for_forward = batch_lerobot_processed["action_mask"]
# Match action horizon if needed
if action_for_forward.shape[1] != original_policy.model.action_horizon:
if action_for_forward.shape[1] < original_policy.model.action_horizon:
pad_size = original_policy.model.action_horizon - action_for_forward.shape[1]
last_action = action_for_forward[:, -1:, :]
padding = last_action.repeat(1, pad_size, 1)
action_for_forward = torch.cat([action_for_forward, padding], dim=1)
mask_padding = torch.zeros(
action_mask_for_forward.shape[0],
pad_size,
action_mask_for_forward.shape[2],
dtype=action_mask_for_forward.dtype,
device=action_mask_for_forward.device,
)
action_mask_for_forward = torch.cat([action_mask_for_forward, mask_padding], dim=1)
else:
action_for_forward = action_for_forward[:, : original_policy.model.action_horizon, :]
action_mask_for_forward = action_mask_for_forward[
:, : original_policy.model.action_horizon, :
]
transformed_obs["action"] = action_for_forward
transformed_obs["action_mask"] = action_mask_for_forward
set_seed_all(42)
with torch.no_grad():
original_outputs = original_policy.model.forward(transformed_obs)
original_loss = original_outputs["loss"]
print(f" Loss: {original_loss.item():.6f}")
loss_diff = abs(lerobot_loss.item() - original_loss.item())
loss_rel_diff = loss_diff / (abs(original_loss.item()) + 1e-8) * 100
print("\nLoss Values:")
print(f" LeRobot: {lerobot_loss.item():.6f}")
print(f" Original: {original_loss.item():.6f}")
print(f" Absolute difference: {loss_diff:.6f}")
print(f" Relative difference: {loss_rel_diff:.2f}%")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation, transformed_obs
cleanup_memory()