removed n1.5 dependency

This commit is contained in:
nv-sachdevkartik
2026-06-04 22:14:07 +00:00
parent fd7fed08e2
commit a35ac22afd
12 changed files with 147 additions and 1679 deletions
+1 -1
View File
@@ -105,7 +105,7 @@ lerobot-train \
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
+1 -1
View File
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
- [SmolVLA](./smolvla)
- [Pi0.5](./pi05)
- [GR00T N1.5](./groot)
- [GR00T N1.7](./groot)
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
+4 -25
View File
@@ -2,11 +2,11 @@
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
LeRobot integrates GR00T through the `groot` policy type. The default model family is GR00T N1.5, and GR00T N1.7 can be selected with `policy.model_version=n1.7`.
LeRobot integrates GR00T N1.7 through the `groot` policy type.
## Model Overview
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. GR00T N1.7 extends the family with a Cosmos-Reason2/Qwen3-VL backbone and N1.7 checkpoints for SimplerEnv, DROID, and LIBERO.
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
@@ -63,12 +63,6 @@ If your CUDA/PyTorch build needs a different Flash Attention wheel or source bui
## Usage
To use GR00T N1.5 in your LeRobot configuration, specify the policy type:
```bash
--policy.type=groot
```
To use GR00T N1.7:
```bash
@@ -104,12 +98,6 @@ accelerate launch \
--job_name=$JOB_NAME
```
For N1.7, add:
```bash
--policy.model_version=n1.7
```
## Performance Results
### LIBERO Benchmark Results
@@ -117,16 +105,7 @@ For N1.7, add:
> [!NOTE]
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
GR00T has demonstrated strong performance on the LIBERO benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the LIBERO dataset and compared the results to the GR00T reference results.
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82.0% | 92.0% |
| **Libero Object** | 99.0% | 92.0% |
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, follow the instructions in the [LIBERO](./libero) section.
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
### GR00T N1.7 LIBERO Checkpoints
@@ -198,4 +177,4 @@ lerobot-rollout\
## License
GR00T N1.5 follows NVIDIA's license terms, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
-1
View File
@@ -26,6 +26,5 @@ Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Models:
- GR00T N1.5: https://huggingface.co/nvidia/GR00T-N1.5-3B
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
@@ -110,7 +110,7 @@ class MultiEmbodimentActionEncoder(nn.Module):
class FlowmatchingActionHeadConfig(PretrainedConfig):
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
"""Flow-matching action head used by GR00T backbones."""
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
@@ -23,18 +23,12 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
GROOT_N1_5 = "n1.5"
GROOT_N1_7 = "n1.7"
GROOT_N1_5_BASE_MODEL = "nvidia/GR00T-N1.5-3B"
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.5": GROOT_N1_5,
"n1_5": GROOT_N1_5,
"n15": GROOT_N1_5,
"1.5": GROOT_N1_5,
"n1.7": GROOT_N1_7,
"n1_7": GROOT_N1_7,
"n1d7": GROOT_N1_7,
@@ -52,7 +46,7 @@ _GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = ", ".join(sorted({GROOT_N1_5, GROOT_N1_7}))
supported = GROOT_N1_7
raise ValueError(
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
)
@@ -80,8 +74,6 @@ def infer_groot_model_version(model_path: str | None) -> str | None:
model_path_lower = model_path.lower()
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
return GROOT_N1_7
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
return GROOT_N1_5
config_version = _infer_groot_model_version_from_local_config(model_path)
if config_version is not None:
return config_version
@@ -296,9 +288,6 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
if normalized in {"gr00t_n1_5", "gr00tn15", "gr00t_n1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
return None
@@ -335,15 +324,12 @@ class GrootConfig(PreTrainedConfig):
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. Defaults to N1.5 to preserve existing behavior.
model_version: str = GROOT_N1_5
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
# Path or HuggingFace model ID for the base Groot model
base_model_path: str | None = None
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
@@ -412,24 +398,18 @@ class GrootConfig(PreTrainedConfig):
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
self.base_model_path = (
GROOT_N1_7_BASE_MODEL if self.model_version == GROOT_N1_7 else GROOT_N1_5_BASE_MODEL
)
self.base_model_path = GROOT_N1_7_BASE_MODEL
if self.action_decode_transform is not None and self.model_version != GROOT_N1_7:
raise ValueError("action_decode_transform can only be used with model_version='n1.7'.")
if self.model_version == GROOT_N1_7:
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
@@ -513,11 +493,7 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
model_action_horizon = 16
if self.model_version == GROOT_N1_7:
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
return list(range(min(self.chunk_size, model_action_horizon)))
@property
-380
View File
@@ -1,380 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from huggingface_hub.dataclasses import strict
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
def strict(cls):
return cls
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
from .action_head.flow_matching_action_head import (
FlowmatchingActionHead,
FlowmatchingActionHeadConfig,
)
from .utils import ensure_eagle_cache_ready
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
class EagleBackbone(nn.Module):
def __init__(
self,
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
project_to_dim: int = 1536,
):
"""
Args:
tune_llm: whether to tune the LLM model (default: True)
tune_visual: whether to tune the visual model (default: False)
"""
super().__init__()
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
try:
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
if project_to_dim is not None:
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
else:
self.eagle_linear = torch.nn.Identity()
# needed since we don't use these layers. Also saves compute
while len(self.eagle_model.language_model.model.layers) > select_layer:
self.eagle_model.language_model.model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual)
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for p in self.parameters():
p.requires_grad = True
if not tune_llm:
self.eagle_model.language_model.requires_grad_(False)
if not tune_visual:
self.eagle_model.vision_model.requires_grad_(False)
self.eagle_model.mlp1.requires_grad_(False)
print(f"Tune backbone llm: {self.tune_llm}")
print(f"Tune backbone visual: {self.tune_visual}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_llm and not tune_visual:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Backbone trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No backbone trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if self.eagle_model.language_model and not self.tune_llm:
self.eagle_model.language_model.eval()
if self.eagle_model.vision_model and not self.tune_visual:
self.eagle_model.vision_model.eval()
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
eagle_prefix = "eagle_"
eagle_input = {
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
}
del eagle_input["image_sizes"]
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
eagle_features = eagle_output.hidden_states[self.select_layer]
eagle_features = self.eagle_linear(eagle_features)
return eagle_features, eagle_input["attention_mask"]
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
if self.training and self.tune_visual:
dummy_term = torch.tensor(
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
)
for param in self.eagle_model.vision_model.parameters():
if param.requires_grad:
dummy_term = dummy_term + 0.0 * param.sum()
eagle_embeds = eagle_embeds + dummy_term
return BatchFeature(
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
) # [B, T2, hidden_size]
BACKBONE_FEATURE_KEY = "backbone_features"
ACTION_KEY = "action_pred"
LOSS_KEY = "loss"
ERROR_MSG = "Error: unexpected input/output"
N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict[str, Any] | None = None
action_head_cfg: dict[str, Any] | None = None
action_horizon: int = 0
action_dim: int = 0
compute_dtype: str = "float32"
def __post_init__(self, **kwargs):
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
super().__post_init__(**kwargs)
# real model
class GR00TN15(PreTrainedModel):
supports_gradient_checkpointing = True
config_class = GR00TN15Config
"""
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
here n is variable and can be e.g. time, 1 or user specified
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
"""
def __init__(
self,
config: GR00TN15Config,
local_model_path: str,
):
assert isinstance(config.backbone_cfg, dict)
assert isinstance(config.action_head_cfg, dict)
super().__init__(config)
self.local_model_path = local_model_path
self.backbone = EagleBackbone(**config.backbone_cfg)
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
self.action_head = FlowmatchingActionHead(action_head_cfg)
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
self.post_init()
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
detected_error = False
error_msg = ERROR_MSG
if ACTION in inputs:
action = inputs[ACTION]
# In inference, action may be omitted or None; validate only when it's a tensor.
if action is None:
pass # allow None during inference
elif isinstance(action, torch.Tensor):
shape_ok = (
len(action.shape) == 3
and action.shape[1] == self.action_horizon
and action.shape[2] == self.action_dim
)
if not shape_ok:
error_msg += f"\n{action.shape=}"
detected_error = True
else:
# Unexpected non-tensor type provided for action
error_msg += f"\nInvalid type for action: {type(action)}"
detected_error = True
if "video" in inputs:
video = inputs["video"]
type_ok = isinstance(video, np.ndarray)
dtype_ok = video.dtype == np.uint8
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
if not type_ok:
error_msg += f"\n{type(video)=}"
detected_error = True
if not dtype_ok:
error_msg += f"\n{video.dtype=}"
detected_error = True
if not shape_ok:
error_msg += f"\n{video.shape=}"
detected_error = True
if detected_error:
raise ValueError(error_msg)
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
fail_backbone = (
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
)
if fail_backbone:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
raise ValueError(error_msg)
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
(
LOSS_KEY in action_head_outputs and is_training
) # there might not be an action prediction during training
or (
ACTION_KEY in action_head_outputs
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
)
)
if fail_action_head:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
error_msg += f"\n{self.action_horizon=}"
error_msg += f"\n{self.action_dim=}"
raise ValueError(error_msg)
def forward(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
return action_head_outputs
def get_action(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
return action_head_outputs
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
self.validate_inputs(inputs)
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_maybe_dtype(x):
# Cast floating tensors to a memory-efficient compute dtype when requested.
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
if getattr(self, "compute_dtype", None) == "bfloat16":
return x.to(self.device, dtype=torch.bfloat16)
# Fallback: preserve previous behavior if not using bf16 compute
return x.to(self.device, dtype=self.action_head.dtype)
# Non-floating tensors: move device only
return x.to(self.device)
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
return backbone_inputs, action_inputs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
print(f"Tune backbone vision tower: {tune_visual}")
print(f"Tune backbone LLM: {tune_llm}")
print(f"Tune action head projector: {tune_projector}")
print(f"Tune action head DiT: {tune_diffusion_model}")
# get the current model path being downloaded
try:
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
# saved in ~/.cache/huggingface/hub/
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
# HFValidationError, RepositoryNotFoundError
except (HFValidationError, RepositoryNotFoundError):
print(
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
)
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path, local_model_path=local_model_path, **kwargs
)
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
)
return pretrained_model
+26 -65
View File
@@ -17,14 +17,8 @@
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
@@ -47,7 +41,6 @@ from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from .configuration_groot import (
GROOT_N1_5,
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
@@ -55,7 +48,6 @@ from .configuration_groot import (
infer_groot_n1_7_action_horizon,
normalize_groot_model_version,
)
from .groot_n1 import GR00TN15
T = TypeVar("T", bound="GrootPolicy")
@@ -80,14 +72,7 @@ class GrootPolicy(PreTrainedPolicy):
self.reset()
def _create_groot_model(self):
"""Create and initialize the GR00T model using Isaac-GR00T API.
This is only called when creating a NEW policy (not when loading from checkpoint).
Steps (delegating to Isaac-GR00T):
1) Download and load pretrained model via GR00TN15.from_pretrained
2) Align action horizon with data_config if provided
"""
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
@@ -98,16 +83,13 @@ class GrootPolicy(PreTrainedPolicy):
"tune_projector": self.config.tune_projector,
"tune_diffusion_model": self.config.tune_diffusion_model,
}
if self.config.model_version == GROOT_N1_7:
from .groot_n1_7 import GR00TN17
from .groot_n1_7 import GR00TN17
model = GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=True,
transformers_loading_kwargs={"trust_remote_code": True},
)
else:
model = GR00TN15.from_pretrained(**model_kwargs)
model = GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=True,
transformers_loading_kwargs={"trust_remote_code": True},
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
@@ -137,7 +119,7 @@ class GrootPolicy(PreTrainedPolicy):
"""Load Groot policy from pretrained model.
Handles two cases:
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
1. Base GR00T N1.7 models - loads the raw model
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
Args:
@@ -163,7 +145,7 @@ class GrootPolicy(PreTrainedPolicy):
requested_version = (
normalize_groot_model_version(config.model_version)
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_5
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
print(
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
@@ -217,7 +199,7 @@ class GrootPolicy(PreTrainedPolicy):
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_5
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
# Create default config with the pretrained path
config = GrootConfig(
model_version=model_version,
@@ -250,18 +232,6 @@ class GrootPolicy(PreTrainedPolicy):
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
if config.model_version == GROOT_N1_7:
if config.max_state_dim == 64:
config.max_state_dim = 132
if config.max_action_dim == 32:
config.max_action_dim = 132
if config.chunk_size == 50:
config.chunk_size = 40
if config.n_action_steps == 50:
config.n_action_steps = 40
if tuple(config.image_size) == (224, 224):
config.image_size = (256, 256)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -274,9 +244,6 @@ class GrootPolicy(PreTrainedPolicy):
def _resolve_action_queue_steps(self) -> int:
n_action_steps = int(self.config.n_action_steps)
if self.config.model_version != GROOT_N1_7:
return n_action_steps
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
@@ -295,9 +262,6 @@ class GrootPolicy(PreTrainedPolicy):
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
if self.config.model_version != GROOT_N1_7:
return actions.shape[1]
horizons = [actions.shape[1]]
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
@@ -318,26 +282,23 @@ class GrootPolicy(PreTrainedPolicy):
if include_action:
allowed_base.update({"action", "action_mask"})
if self.config.model_version == GROOT_N1_7:
allowed_base.update(
{
"input_ids",
"attention_mask",
"pixel_values",
"image_grid_thw",
"mm_token_type_ids",
"pixel_values_videos",
"video_grid_thw",
}
)
allowed_base.add("action_mask")
else:
allowed_base.update({"action_mask"} if include_action else set())
allowed_base.update(
{
"input_ids",
"attention_mask",
"pixel_values",
"image_grid_thw",
"mm_token_type_ids",
"pixel_values_videos",
"video_grid_thw",
}
)
allowed_base.add("action_mask")
return {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
if k in allowed_base and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
@@ -347,7 +308,7 @@ class GrootPolicy(PreTrainedPolicy):
inference_delay: object,
prev_chunk_left_over: object,
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
if self.config.model_version != GROOT_N1_7 or prev_chunk_left_over is None:
if prev_chunk_left_over is None:
return inputs, None
if not isinstance(prev_chunk_left_over, torch.Tensor):
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
+71 -570
View File
@@ -28,16 +28,10 @@ from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, ProcessorMixin
from transformers import ProcessorMixin
else:
AutoProcessor = None
ProcessorMixin = object
from lerobot.configs import (
FeatureType,
NormalizationMode,
PolicyFeature,
)
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -54,7 +48,6 @@ from lerobot.processor import (
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
ACTION,
HF_LEROBOT_HOME,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
@@ -64,15 +57,11 @@ from lerobot.utils.constants import (
from .configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_7,
GROOT_N1_7_BACKBONE_MODEL,
GrootConfig,
is_raw_groot_n1_7_checkpoint,
)
# Defaults for Eagle processor locations
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
N1_7_EMBODIMENT_MAPPING = {
"oxe_droid_relative_eef_relative_joint": 24,
"xdof_relative_eef_relative_joint": 27,
@@ -470,14 +459,10 @@ def _legacy_groot_processor_overrides(
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
pack_inputs_key = (
"groot_n1_7_pack_inputs_v1" if config.model_version == GROOT_N1_7 else "groot_pack_inputs_v3"
)
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
pack_input_overrides["normalize_min_max"] = True
if dataset_stats is not None and config.model_version != GROOT_N1_7:
pack_input_overrides["stats"] = dataset_stats
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
try:
@@ -487,8 +472,6 @@ def _legacy_groot_processor_overrides(
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
if dataset_stats is not None and config.model_version != GROOT_N1_7:
action_unpack_overrides["stats"] = dataset_stats
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
@@ -528,13 +511,10 @@ def make_groot_pre_post_processors_from_pretrained(
dataset_stats=dataset_stats,
)
if (
config.model_version == GROOT_N1_7
and _local_processor_config_has_step(
pretrained_path,
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
)
if _local_processor_config_has_step(
pretrained_path,
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
):
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
# action decoder. Adding the legacy action-unpack override would target
@@ -596,180 +576,86 @@ def make_groot_pre_post_processors(
Tuple of (preprocessor, postprocessor) pipelines
"""
if config.model_version == GROOT_N1_7:
checkpoint_assets = _load_n1_7_checkpoint_processor_assets(config)
action_horizon = (
checkpoint_assets.max_action_horizon
if checkpoint_assets is not None and checkpoint_assets.max_action_horizon is not None
else min(config.chunk_size, 40)
)
valid_action_horizon = (
checkpoint_assets.valid_action_horizon
if checkpoint_assets is not None and checkpoint_assets.valid_action_horizon is not None
else action_horizon
)
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {})
embodiment_mapping = (
checkpoint_assets.embodiment_mapping
if checkpoint_assets is not None
else dict(N1_7_EMBODIMENT_MAPPING)
)
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
video_modality_keys = checkpoint_assets.video_modality_keys if checkpoint_assets is not None else None
try:
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
state_cache_key = f"groot_n1_7:{config.embodiment_tag}"
pack_step = GrootN17PackInputsStep(
state_horizon=1,
action_horizon=action_horizon,
valid_action_horizon=valid_action_horizon,
video_horizon=checkpoint_assets.video_horizon if checkpoint_assets is not None else None,
max_state_dim=config.max_state_dim,
max_action_dim=config.max_action_dim,
language_key="task",
formalize_language=formalize_language,
embodiment_tag=config.embodiment_tag,
embodiment_mapping=embodiment_mapping,
normalize_min_max=True,
stats=padded_stats,
clip_outliers=clip_outliers,
video_modality_keys=video_modality_keys,
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
state_cache_key=state_cache_key,
)
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
pack_step,
GrootN17VLMEncodeStep(
model_name=config.n1_7_backbone_model,
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
image_target_size=checkpoint_assets.image_target_size
if checkpoint_assets is not None
else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge
if checkpoint_assets is not None
else None,
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
use_albumentations=checkpoint_assets.use_albumentations
if checkpoint_assets is not None
else False,
),
DeviceProcessorStep(device=config.device),
]
if checkpoint_assets is None:
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
env_action_dim=env_action_dim,
stats=padded_stats,
normalize_min_max=True,
clip_normalized_action=True,
)
else:
action_decode_step = GrootN17ActionDecodeStep(
env_action_dim=env_action_dim,
raw_stats=checkpoint_assets.raw_stats,
modality_config=checkpoint_assets.modality_config,
use_percentiles=checkpoint_assets.use_percentiles,
use_relative_action=checkpoint_assets.use_relative_action,
pack_step=pack_step,
state_cache_key=state_cache_key,
action_decode_transform=config.action_decode_transform,
)
output_steps: list[ProcessorStep] = [
action_decode_step,
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
# Get horizon/dimension parameters from config
# These should match the config used for the pretrained model
# Default values match most GR00T configs (state_horizon=1, action_horizon=16)
state_horizon = 1
# CRITICAL: Pretrained GR00T models use action_horizon=16 max!
# The model architecture hardcodes this limit
action_horizon = min(config.chunk_size, 16)
max_state_dim = config.max_state_dim
max_action_dim = config.max_action_dim
# Pass raw dataset_stats; normalization will occur inside pack step before padding
padded_stats = dataset_stats or {}
# Define feature specs for optional normalization steps
_features: dict[str, PolicyFeature] = {
# Observation features (only add those we may normalize)
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
# Action feature
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
}
# Normalize STATE and ACTION with min_max (SO100-like default)
_norm_map = {
FeatureType.ACTION: NormalizationMode.MIN_MAX,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
# Determine env action dimension from config (simple, object-like PolicyFeature)
checkpoint_assets = _load_n1_7_checkpoint_processor_assets(config)
action_horizon = (
checkpoint_assets.max_action_horizon
if checkpoint_assets is not None and checkpoint_assets.max_action_horizon is not None
else min(config.chunk_size, 40)
)
valid_action_horizon = (
checkpoint_assets.valid_action_horizon
if checkpoint_assets is not None and checkpoint_assets.valid_action_horizon is not None
else action_horizon
)
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {})
embodiment_mapping = (
checkpoint_assets.embodiment_mapping if checkpoint_assets is not None else dict(N1_7_EMBODIMENT_MAPPING)
)
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
video_modality_keys = checkpoint_assets.video_modality_keys if checkpoint_assets is not None else None
try:
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
state_cache_key = f"groot_n1_7:{config.embodiment_tag}"
pack_step = GrootN17PackInputsStep(
state_horizon=1,
action_horizon=action_horizon,
valid_action_horizon=valid_action_horizon,
video_horizon=checkpoint_assets.video_horizon if checkpoint_assets is not None else None,
max_state_dim=config.max_state_dim,
max_action_dim=config.max_action_dim,
language_key="task",
formalize_language=formalize_language,
embodiment_tag=config.embodiment_tag,
embodiment_mapping=embodiment_mapping,
normalize_min_max=True,
stats=padded_stats,
clip_outliers=clip_outliers,
video_modality_keys=video_modality_keys,
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
state_cache_key=state_cache_key,
)
input_steps: list[ProcessorStep] = [
# 1. Rename keys if needed (e.g., dataset-specific camera names)
# Leave empty for now - add mappings if your dataset uses different key names
RenameObservationsProcessorStep(rename_map={}),
# 2. Add batch dimension for single samples
AddBatchDimensionProcessorStep(),
# 3. Pack video/state/action/language/embodiment; apply optional min-max normalization before padding
GrootPackInputsStep(
state_horizon=state_horizon,
action_horizon=action_horizon,
max_state_dim=max_state_dim,
max_action_dim=max_action_dim,
language_key="task",
formalize_language=False,
embodiment_tag=config.embodiment_tag,
normalize_min_max=True,
stats=padded_stats,
pack_step,
GrootN17VLMEncodeStep(
model_name=config.n1_7_backbone_model,
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge if checkpoint_assets is not None else None,
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
use_albumentations=checkpoint_assets.use_albumentations if checkpoint_assets is not None else False,
),
# 4. Eagle encode (creates eagle_content)
GrootEagleEncodeStep(
tokenizer_assets_repo=config.tokenizer_assets_repo,
),
# 5. Collate eagle_content -> eagle_* tensors
GrootEagleCollateStep(
tokenizer_assets_repo=config.tokenizer_assets_repo,
),
# 6. Move to device
DeviceProcessorStep(device=config.device),
]
# Postprocessing: slice to env action dim and unnormalize to env scale, then move to CPU
output_steps: list[ProcessorStep] = [
GrootActionUnpackUnnormalizeStep(
if checkpoint_assets is None:
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
env_action_dim=env_action_dim,
stats=padded_stats,
normalize_min_max=True,
),
# Finally, move to CPU for env interaction
clip_normalized_action=True,
)
else:
action_decode_step = GrootN17ActionDecodeStep(
env_action_dim=env_action_dim,
raw_stats=checkpoint_assets.raw_stats,
modality_config=checkpoint_assets.modality_config,
use_percentiles=checkpoint_assets.use_percentiles,
use_relative_action=checkpoint_assets.use_relative_action,
pack_step=pack_step,
state_cache_key=state_cache_key,
action_decode_transform=config.action_decode_transform,
)
output_steps: list[ProcessorStep] = [
action_decode_step,
DeviceProcessorStep(device="cpu"),
]
@@ -801,10 +687,6 @@ def _to_uint8_np_bthwc(img_t: torch.Tensor) -> np.ndarray:
raise ValueError(f"Expected image tensor shape (B, C, H, W) or (B, T, C, H, W), got {tuple(img_t.shape)}")
def _to_uint8_np_bhwc(img_t: torch.Tensor) -> np.ndarray:
return _to_uint8_np_bthwc(img_t)[:, 0]
def _align_video_horizon(video: np.ndarray, horizon: int | None) -> np.ndarray:
"""Match the checkpoint video horizon by truncating or left-padding frames."""
@@ -865,29 +747,6 @@ def _prepare_n1_7_language_batch(
return formatted
def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO) -> ProcessorMixin:
# Validate that the cache directory is ready. If not, instruct the user.
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
required = [
cache_dir / "processor_config.json",
cache_dir / "preprocessor_config.json",
cache_dir / "image_processing_eagle2_5_vl_fast.py",
]
if not all(p.exists() for p in required):
raise FileNotFoundError(
f"[GROOT] Eagle processor cache at '{cache_dir}' is not populated. "
"Vendor files are copied during model creation. Create the policy/model first, "
"or call ensure_eagle_cache_ready() before building processors."
)
proc = AutoProcessor.from_pretrained(
str(cache_dir),
trust_remote_code=True,
fix_mistral_regex=False,
)
proc.tokenizer.padding_side = "left"
return proc
def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> ProcessorMixin:
try:
from transformers import (
@@ -998,228 +857,6 @@ def _transform_n1_7_image_for_vlm(
return image
@dataclass
@ProcessorStepRegistry.register(name="groot_pack_inputs_v3")
class GrootPackInputsStep(ProcessorStep):
state_horizon: int = 1
action_horizon: int = 16
max_state_dim: int = 64
max_action_dim: int = 32
language_key: str = "task"
formalize_language: bool = False
embodiment_tag: str = "new_embodiment"
embodiment_mapping: dict[str, int] = field(
default_factory=lambda: {
"new_embodiment": 31, # Match original GR00T EMBODIMENT_TAG_MAPPING
"oxe_droid": 17,
"agibot_genie1": 26,
"gr1": 24,
"so100": 2,
"unitree_g1": 3,
}
)
# Min-max normalization (SO100-like) applied BEFORE padding
normalize_min_max: bool = True
stats: dict[str, dict[str, Any]] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
def _align_vec(vec: Any, target_dim: int, *, default: float) -> torch.Tensor:
t = torch.as_tensor(vec)
t = t.flatten().to(
dtype=torch.float32,
device=next(
(v.device for v in obs.values() if isinstance(v, torch.Tensor)), torch.device("cpu")
),
)
d = int(t.shape[-1]) if t.numel() > 0 else 0
if d == target_dim:
return t
if d < target_dim:
pad = torch.full((target_dim - d,), default, dtype=t.dtype, device=t.device)
return torch.cat([t, pad], dim=0)
return t[:target_dim]
def _min_max_norm(x: torch.Tensor, key: str) -> torch.Tensor:
if not self.normalize_min_max:
return x
if self.stats is None or key not in self.stats:
return x
stats_k = self.stats[key]
last_dim = x.shape[-1]
min_v = _align_vec(stats_k.get("min", torch.zeros(last_dim)), last_dim, default=0.0)
max_v = _align_vec(stats_k.get("max", torch.ones(last_dim)), last_dim, default=1.0)
denom = max_v - min_v
mask = denom != 0
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
mapped = 2 * (x - min_v) / safe_denom - 1
return torch.where(mask, mapped, torch.zeros_like(mapped))
# 1) Video (B, T=1, V, H, W, C) uint8
img_keys = sorted([k for k in obs if k.startswith(OBS_IMAGES)])
if not img_keys and OBS_IMAGE in obs:
img_keys = [OBS_IMAGE]
if img_keys:
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
video = np.stack(cams, axis=1) # (B, V, H, W, C)
video = np.expand_dims(video, axis=1) # (B, 1, V, H, W, C)
# GR00T validates that video.shape[3] == 3 (channels), so reorder to (B, T, V, C, H, W)
video = np.transpose(video, (0, 1, 2, 5, 3, 4)) # (B, 1, V, C, H, W)
obs["video"] = video
# Drop raw images to avoid confusion downstream
for k in img_keys:
obs.pop(k, None)
# 2) Language (string)
lang = comp.get(self.language_key)
if isinstance(lang, list):
lang = lang[0] if len(lang) > 0 else None
if not lang:
lang = "Perform the task."
if self.formalize_language:
lang = (lang or "").lower()
lang = "".join(ch for ch in lang if ch.isalnum() or ch.isspace())
comp["language"] = lang
# 3) State/state_mask -> (B, 1, max_state_dim)
if OBS_STATE in obs:
state = obs[OBS_STATE] # (B, D)
if state.dim() != 2:
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
bsz, d = state.shape
# Normalize BEFORE padding
if self.normalize_min_max:
state = _min_max_norm(state, OBS_STATE)
state = state.unsqueeze(1) # (B, 1, D)
if d > self.max_state_dim:
state = state[:, :, : self.max_state_dim]
d = self.max_state_dim
elif d < self.max_state_dim:
pad = torch.zeros(bsz, 1, self.max_state_dim - d, dtype=state.dtype, device=state.device)
state = torch.cat([state, pad], dim=2)
state_mask = torch.zeros(bsz, 1, self.max_state_dim, dtype=torch.bool, device=state.device)
state_mask[:, :, :d] = True
obs["state"] = state
obs["state_mask"] = state_mask
# 4) Action/action_mask -> (B, action_horizon, max_action_dim)
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
# Normalize BEFORE temporal expansion/padding
if self.normalize_min_max:
if action.dim() == 2:
action = _min_max_norm(action, ACTION)
elif action.dim() == 3:
b, t, d = action.shape
flat = action.reshape(b * t, d)
flat = _min_max_norm(flat, ACTION)
action = flat.view(b, t, d)
if action.dim() == 2:
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
elif action.dim() == 3:
b, t, d = action.shape
if t < self.action_horizon:
last = action[:, -1:, :]
pad = last.repeat(1, self.action_horizon - t, 1)
action = torch.cat([action, pad], dim=1)
elif t > self.action_horizon:
action = action[:, : self.action_horizon, :]
else:
raise ValueError(f"action must be (B, D) or (B, T, D), got {tuple(action.shape)}")
b, t, d = action.shape
if d > self.max_action_dim:
action = action[:, :, : self.max_action_dim]
d = self.max_action_dim
elif d < self.max_action_dim:
pad = torch.zeros(b, t, self.max_action_dim - d, dtype=action.dtype, device=action.device)
action = torch.cat([action, pad], dim=2)
action_mask = torch.zeros(b, t, self.max_action_dim, dtype=torch.bool, device=action.device)
action_mask[:, :, :d] = True
transition[TransitionKey.ACTION] = action
comp["action_mask"] = action_mask
# 5) Embodiment id as LongTensor (B,)
emb_id = self.embodiment_mapping.get(self.embodiment_tag, 0)
# Infer batch size/device from any tensor in obs or action
bsz = None
device = torch.device("cpu")
for v in list(obs.values()) + [transition.get(TransitionKey.ACTION)]:
if isinstance(v, torch.Tensor):
bsz = v.shape[0]
device = v.device
break
if bsz is None and "video" in obs and isinstance(obs["video"], np.ndarray):
bsz = obs["video"].shape[0]
if bsz is None:
bsz = 1
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.long, device=device)
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
# Pipeline API requirement: declare how features change (we keep it simple)
def transform_features(self, features):
return features
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
Excludes 'stats' since they are saved separately via state_dict().
"""
return {
"state_horizon": self.state_horizon,
"action_horizon": self.action_horizon,
"max_state_dim": self.max_state_dim,
"max_action_dim": self.max_action_dim,
"language_key": self.language_key,
"formalize_language": self.formalize_language,
"embodiment_tag": self.embodiment_tag,
"embodiment_mapping": self.embodiment_mapping,
"normalize_min_max": self.normalize_min_max,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""
Returns normalization statistics as a flat state dictionary.
This enables saving stats to safetensors files, similar to normalizer_processor.
"""
if not self.stats:
return {}
flat: dict[str, torch.Tensor] = {}
for key, sub in self.stats.items():
for stat_name, value in sub.items():
tensor = torch.as_tensor(value).cpu()
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""
Loads normalization statistics from a flat state dictionary.
This enables loading stats from safetensors files during from_pretrained.
"""
if not state:
return
reconstructed: dict[str, dict[str, Any]] = {}
for flat_key, tensor in state.items():
if "." in flat_key:
key, stat_name = flat_key.rsplit(".", 1)
if key not in reconstructed:
reconstructed[key] = {}
reconstructed[key][stat_name] = tensor
if reconstructed:
self.stats = reconstructed
@dataclass
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
class GrootN17PackInputsStep(ProcessorStep):
@@ -1567,142 +1204,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"use_albumentations": self.use_albumentations,
}
@dataclass
@ProcessorStepRegistry.register(name="groot_eagle_encode_v3")
class GrootEagleEncodeStep(ProcessorStep):
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
def proc(self) -> ProcessorMixin:
if self._proc is None:
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
return self._proc
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
if "video" not in obs:
return transition
video = obs["video"] # (B, T, V, H, W, C) uint8
lang = comp.get("language", "Perform the task.")
if isinstance(lang, list):
lang = lang[0] if len(lang) > 0 else "Perform the task."
bsz = video.shape[0]
eagle_contents: list[dict[str, Any]] = []
for b in range(bsz):
vt = video[b] # (T, V, C, H, W) after reorder
if vt.ndim != 5:
# Fallback: assume (T, V, H, W, C)
t, v, h, w, c = vt.shape
flat = rearrange(vt, "t v h w c -> (t v) h w c")
else:
t, v, c, h, w = vt.shape
flat = rearrange(vt, "t v c h w -> (t v) h w c")
images = [Image.fromarray(flat[i]) for i in range(t * v)]
# Format language as string list representation to match Original GROOT
lang_formatted = str([lang])
text_content = [{"type": "text", "text": lang_formatted}]
image_content = [{"type": "image", "image": img} for img in images]
conv = [{"role": "user", "content": image_content + text_content}]
text_list = [self.proc.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)]
img_inputs, vid_inputs = self.proc.process_vision_info(conv)
eagle_contents.append(
{
"text_list": text_list,
"image_inputs": img_inputs,
"video_inputs": vid_inputs,
}
)
comp["eagle_content"] = eagle_contents
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
# Pipeline API requirement: declare how features change (no schema change here)
def transform_features(self, features):
return features
# Original GR00T-style collate: converts eagle_content -> eagle_* tensors
def collate(features: list[dict[str, Any]], eagle_processor: ProcessorMixin) -> dict[str, Any]:
batch: dict[str, Any] = {}
keys = features[0].keys()
for key in keys:
values = [elem[key] for elem in features]
if key == "eagle_content":
text_list: list[str] = []
image_inputs: list[Any] = []
for v in values:
curr_text_list = v["text_list"]
curr_image_inputs = v["image_inputs"]
text_list += curr_text_list
image_inputs += curr_image_inputs
eagle_inputs = eagle_processor(
text=text_list,
images=image_inputs,
images_kwargs={"min_dynamic_tiles": 1, "max_dynamic_tiles": 1, "use_thumbnail": False},
return_tensors="pt",
padding=True,
)
for k, v in eagle_inputs.items():
k = "eagle_" + k
batch[k] = v
elif key in ("pixel_values", "image_grid_thw", "attention_mask", "input_ids"):
# Concat in existing batch dimension.
batch[key] = torch.cat(values)
else:
# state, state_mask, action and action_mask.
# Stack to form the batch dimension.
batch[key] = torch.from_numpy(np.stack(values))
return batch
@dataclass
@ProcessorStepRegistry.register(name="groot_eagle_collate_v3")
class GrootEagleCollateStep(ProcessorStep):
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
def proc(self) -> ProcessorMixin:
if self._proc is None:
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
return self._proc
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
contents = comp.get("eagle_content")
if not contents:
return transition
# Build features list as original API expects: one dict per batch item
features = [{"eagle_content": content} for content in contents]
batched = collate(features, self.proc)
# Inject eagle_* tensors and remove the temporary content and raw video to free memory
for k, v in batched.items():
comp[k] = v
comp.pop("eagle_content", None)
obs.pop(
"video", None
) # The video has been fully encoded into eagle_* tensors, so we don't need the raw video anymore
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
def transform_features(self, features):
return features
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
for stat_name in ("mean", "q01", "min", "max", "std"):
value = entry.get(stat_name)
-47
View File
@@ -1,47 +0,0 @@
from pathlib import Path
from shutil import copytree
from huggingface_hub import hf_hub_download
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
"""
cache_dir = Path(cache_dir)
vendor_dir = Path(vendor_dir)
try:
# Populate/refresh cache with vendor files to ensure a complete processor directory
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
required_assets = [
"vocab.json",
"merges.txt",
"added_tokens.json",
"chat_template.json",
"special_tokens_map.json",
"config.json",
"generation_config.json",
"preprocessor_config.json",
"processor_config.json",
"tokenizer_config.json",
]
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
for fname in required_assets:
dst = cache_dir / fname
if not dst.exists():
print(f"[GROOT] Fetching {fname}")
hf_hub_download(
repo_id=assets_repo,
filename=fname,
repo_type="model",
local_dir=str(cache_dir),
)
+28 -105
View File
@@ -29,8 +29,6 @@ from lerobot.configs import FeatureType, PolicyFeature
from lerobot.policies.factory import make_policy_config, make_pre_post_processors
from lerobot.policies.groot.configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_5,
GROOT_N1_5_BASE_MODEL,
GROOT_N1_7,
GROOT_N1_7_BASE_MODEL,
GrootConfig,
@@ -40,7 +38,6 @@ from lerobot.policies.groot.configuration_groot import (
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import (
GrootActionUnpackUnnormalizeStep,
GrootEagleEncodeStep,
GrootN17ActionDecodeStep,
GrootN17PackInputsStep,
GrootN17VLMEncodeStep,
@@ -64,11 +61,9 @@ def _groot_features(
)
def _groot_config(model_version: str) -> GrootConfig:
def _groot_config(model_version: str = GROOT_N1_7) -> GrootConfig:
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
kwargs = {}
if model_version == GROOT_N1_7:
kwargs["action_decode_transform"] = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
kwargs = {"action_decode_transform": GROOT_ACTION_DECODE_TRANSFORM_LIBERO}
return GrootConfig(
model_version=model_version,
input_features=input_features,
@@ -347,19 +342,9 @@ class _DummyGrootModel(nn.Module):
return {"action_pred": torch.zeros(batch_size, 40, 132, device=self.weight.device)}
def test_groot_n1_5_defaults_are_preserved():
def test_groot_defaults_use_n1_7():
config = GrootConfig(device="cpu")
assert config.model_version == GROOT_N1_5
assert config.base_model_path == GROOT_N1_5_BASE_MODEL
assert config.max_state_dim == 64
assert config.max_action_dim == 32
assert len(config.action_delta_indices) == 16
def test_groot_n1_7_explicit_selection_uses_n1_7_defaults():
config = GrootConfig(model_version=GROOT_N1_7, device="cpu")
assert config.model_version == GROOT_N1_7
assert config.base_model_path == GROOT_N1_7_BASE_MODEL
assert config.max_state_dim == 132
@@ -389,25 +374,17 @@ def test_groot_n1_7_rejects_legacy_libero_gripper_action_decode_transform(legacy
)
def test_groot_n1_5_rejects_action_decode_transform():
with pytest.raises(ValueError, match="action_decode_transform"):
GrootConfig(
model_version=GROOT_N1_5,
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
device="cpu",
)
def test_groot_n1_7_path_requires_matching_model_version():
with pytest.raises(ValueError, match="model_version"):
GrootConfig(base_model_path=GROOT_N1_7_BASE_MODEL, device="cpu")
@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"])
def test_groot_rejects_n1_5_aliases(legacy_version):
with pytest.raises(ValueError, match="Unsupported GR00T model_version"):
GrootConfig(model_version=legacy_version, device="cpu")
def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7():
with pytest.raises(ValueError, match="does not match base_model_path"):
GrootConfig(
model_version=GROOT_N1_7,
base_model_path=GROOT_N1_5_BASE_MODEL,
base_model_path="nvidia/GR00T-N1.5-3B",
device="cpu",
)
@@ -528,7 +505,16 @@ def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(t
def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
model_path = tmp_path / "GR00T-N1.7-local"
model_path.mkdir()
config = _groot_config(GROOT_N1_5)
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
)
with pytest.raises(ValueError, match="does not match base_model_path"):
GrootPolicy.from_pretrained(model_path, config=config)
@@ -1251,7 +1237,16 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(
model_path = tmp_path / "local-checkpoint"
model_path.mkdir()
(model_path / "config.json").write_text('{"model_type": "Gr00tN1d7"}')
config = _groot_config(GROOT_N1_5)
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
config = GrootConfig(
model_version=GROOT_N1_7,
base_model_path="nvidia/GR00T-N1.5-3B",
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
)
with pytest.raises(ValueError, match="does not match base_model_path"):
GrootPolicy.from_pretrained(model_path, config=config)
@@ -1266,20 +1261,9 @@ def test_groot_n1_7_processors_are_registered_lazily_without_external_gr00t():
assert GrootN17PackInputsStep in step_types
assert GrootN17VLMEncodeStep in step_types
assert GrootEagleEncodeStep not in step_types
assert "gr00t" not in sys.modules
def test_groot_n1_5_processors_still_use_eagle_path():
config = _groot_config(GROOT_N1_5)
preprocessor, _ = make_groot_pre_post_processors(config)
step_types = {type(step) for step in preprocessor.steps}
assert GrootEagleEncodeStep in step_types
assert GrootN17VLMEncodeStep not in step_types
def test_groot_n1_7_pack_inputs_preserves_per_sample_language():
step = GrootN17PackInputsStep(
action_horizon=2,
@@ -1672,67 +1656,6 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat
assert unpack_step.env_action_dim == 7
def test_groot_legacy_n1_5_processors_reload_with_compatibility_overrides(tmp_path):
config = _groot_config(GROOT_N1_5)
dataset_stats = {
OBS_STATE: {
"min": torch.full((8,), -1.0),
"max": torch.full((8,), 1.0),
},
ACTION: {
"min": torch.full((7,), -2.0),
"max": torch.full((7,), 2.0),
},
}
legacy_preprocessor_config = {
"name": "policy_preprocessor",
"steps": [
{
"registry_name": "groot_pack_inputs_v3",
"config": {
"state_horizon": 1,
"action_horizon": 16,
"max_state_dim": config.max_state_dim,
"max_action_dim": config.max_action_dim,
"language_key": "task",
"formalize_language": False,
"embodiment_tag": config.embodiment_tag,
"embodiment_mapping": {"new_embodiment": 31},
"normalize_min_max": False,
},
}
],
}
legacy_postprocessor_config = {
"name": "policy_postprocessor",
"steps": [
{
"registry_name": "groot_action_unpack_unnormalize_v1",
"config": {
"env_action_dim": 0,
"normalize_min_max": False,
},
}
],
}
(tmp_path / "policy_preprocessor.json").write_text(json.dumps(legacy_preprocessor_config))
(tmp_path / "policy_postprocessor.json").write_text(json.dumps(legacy_postprocessor_config))
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
config,
pretrained_path=str(tmp_path),
dataset_stats=dataset_stats,
)
pack_step = loaded_preprocessor.steps[0]
unpack_step = loaded_postprocessor.steps[0]
assert pack_step.normalize_min_max
assert unpack_step.normalize_min_max
assert unpack_step.env_action_dim == 7
torch.testing.assert_close(pack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
torch.testing.assert_close(pack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
torch.testing.assert_close(unpack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
torch.testing.assert_close(unpack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
@@ -1,444 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script to verify Groot policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import gc
import os
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.modeling_groot import GrootPolicy
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
pytest.importorskip("gr00t")
pytest.importorskip("transformers")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local Groot installation and is not meant for CI",
)
from gr00t.data.dataset import ModalityConfig # noqa: E402
from gr00t.data.embodiment_tags import EmbodimentTag # noqa: E402
from gr00t.data.transform.base import ComposedModalityTransform # noqa: E402
from gr00t.model.policy import Gr00tPolicy # noqa: E402
# GR1 humanoid dimensions (from pretrained model metadata)
# The actual GR1 robot has 44 dimensions for both state and action
# GR00TTransform will pad state to 64 and truncate action to 32
DUMMY_STATE_DIM = 44
DUMMY_ACTION_DIM = 44
DUMMY_ACTION_HORIZON = 16
IMAGE_SIZE = 256
DEVICE = "cpu"
MODEL_PATH = "nvidia/GR00T-N1.5-3B"
GR1_BODY_PARTS = {
"left_arm": 7,
"left_hand": 6,
"left_leg": 6,
"neck": 3,
"right_arm": 7,
"right_hand": 6,
"right_leg": 6,
"waist": 3,
}
def cleanup_memory():
"""Clean up GPU/MPS memory to prevent OOM errors between tests."""
print("\nCleaning up memory...")
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
print("Memory cleanup complete.")
def set_seed_all(seed: int):
"""Set random seed for all RNG sources to ensure reproducibility."""
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Set deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True, warn_only=True)
def instantiate_lerobot_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
) -> tuple[
GrootPolicy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Instantiate LeRobot Groot policy with preprocessor and postprocessor."""
if from_pretrained:
policy = GrootPolicy.from_pretrained(
pretrained_name_or_path=model_path,
strict=False,
)
policy.config.embodiment_tag = "gr1"
else:
config = GrootConfig(
base_model_path=model_path,
n_action_steps=DUMMY_ACTION_HORIZON,
chunk_size=DUMMY_ACTION_HORIZON,
image_size=[IMAGE_SIZE, IMAGE_SIZE],
device=DEVICE,
embodiment_tag="gr1",
)
policy = GrootPolicy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_groot_pre_post_processors(
config=policy.config,
dataset_stats=None, # Pass None for dataset_stats to disable normalization (original GR00T doesn't normalize)
)
return (policy, preprocessor, postprocessor)
def instantiate_original_groot(
from_pretrained: bool = False,
model_path: str = MODEL_PATH,
):
"""Instantiate original Groot policy from NVIDIA's implementation."""
from gr00t.data.transform.concat import ConcatTransform
from gr00t.data.transform.state_action import StateActionToTensor
from gr00t.data.transform.video import VideoToNumpy, VideoToTensor
from gr00t.model.transforms import GR00TTransform
video_keys = ["video.ego_view"]
state_keys = [
"state"
] # Important: Use single concatenated "state" key (not split body parts) to match preprocessing
action_keys = [
"action.left_arm",
"action.right_arm",
"action.left_hand",
"action.right_hand",
"action.left_leg",
"action.right_leg",
"action.neck",
"action.waist",
]
language_keys = ["annotation.human.action.task_description"]
modality_config = {
"video": ModalityConfig(
delta_indices=[0], # Current frame only
modality_keys=video_keys,
),
"state": ModalityConfig(
delta_indices=[0],
modality_keys=state_keys,
),
"action": ModalityConfig(
delta_indices=list(range(DUMMY_ACTION_HORIZON)),
modality_keys=action_keys,
),
"language": ModalityConfig(
delta_indices=[0],
modality_keys=language_keys,
),
}
modality_transform = ComposedModalityTransform(
transforms=[
VideoToTensor(apply_to=video_keys),
VideoToNumpy(apply_to=video_keys), # Convert to numpy (GR00TTransform expects numpy arrays)
# State is already a single concatenated key, so no StateActionToTensor needed
# Convert action from numpy to tensor
StateActionToTensor(apply_to=action_keys),
# Concatenate only video and actions (state is already single key)
ConcatTransform(
video_concat_order=video_keys,
state_concat_order=[], # Empty:state is already single key
action_concat_order=action_keys,
),
GR00TTransform(
max_state_dim=64,
max_action_dim=32,
state_horizon=1,
action_horizon=DUMMY_ACTION_HORIZON,
training=False,
),
]
)
policy = Gr00tPolicy(
model_path=model_path,
embodiment_tag=EmbodimentTag.GR1,
modality_config=modality_config,
modality_transform=modality_transform,
device=DEVICE,
)
return policy, modality_config, modality_transform
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 2
prompt = "Pick up the red cube and place it in the bin"
state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)
batch = {
"observation.state": state,
"action": torch.randn(
batch_size,
DUMMY_ACTION_HORIZON,
DUMMY_ACTION_DIM,
dtype=torch.float32,
device=device, # Action ground truth (for training)
),
"observation.images.ego_view": torch.rand(
batch_size,
3,
IMAGE_SIZE,
IMAGE_SIZE,
dtype=torch.float32,
device=device, # Images in [0, 1] range as expected by LeRobot
),
"task": [prompt for _ in range(batch_size)],
}
return batch
def convert_lerobot_to_original_format(batch, modality_config):
"""Convert LeRobot batch format to original Groot format.
The original Groot expects observations in this format:
{
"video.<camera_name>": np.ndarray (T, H, W, C) or (B, T, H, W, C)
"state.<state_component>": np.ndarray (T, D) or (B, T, D)
"action.<action_component>": np.ndarray (T, D) or (B, T, D)
"annotation.<annotation_type>": str or list[str]
}
"""
# Original Groot expects (T, H, W, C) format for images
# LeRobot has (B, C, H, W) format, so we need to convert
observation = {}
for img_key in ["ego_view"]:
lerobot_key = f"observation.images.{img_key}"
if lerobot_key in batch:
img = batch[lerobot_key]
# Convert from (B, C, H, W) to (B, T=1, H, W, C)
img_np = img.permute(0, 2, 3, 1).unsqueeze(1).cpu().numpy()
# Convert [0, 1] to [0, 255] uint8 as expected by original
img_np = (img_np * 255).astype(np.uint8)
observation[f"video.{img_key}"] = img_np
# Important: The Original's GR00TTransform expects "state" as (B, T, D), not split body parts
if "observation.state" in batch:
state = batch["observation.state"]
state_np = state.unsqueeze(1).cpu().numpy() # (B, 1, D)
observation["state"] = state_np
if "action" in batch:
action = batch["action"]
action_np = action.cpu().numpy()
start_idx = 0
for part_name, part_dim in GR1_BODY_PARTS.items():
end_idx = start_idx + part_dim
observation[f"action.{part_name}"] = action_np[:, :, start_idx:end_idx]
start_idx = end_idx
if "task" in batch:
task_list = batch["task"]
# GR00TTransform expects language with (B, T) shape for batched data
# Create a (B, T=1) array where each element is the string directly
bsz = len(task_list)
task_array = np.empty((bsz, 1), dtype=object)
for i in range(bsz):
task_array[i, 0] = task_list[i] # Assign string directly to each (i, 0) position
observation["annotation.human.action.task_description"] = task_array
return observation
def test_groot_original_vs_lerobot_pretrained():
"""Test Groot original implementation vs LeRobot implementation with pretrained weights."""
print("Test: Groot Original vs LeRobot with Pretrained Weights (Inference)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
print("\n[LeRobot] Running inference...")
lerobot_policy.eval()
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
lerobot_actions = lerobot_policy.select_action(batch_lerobot_processed)
print("\n[Original] Running inference...")
original_policy.model.eval()
observation = convert_lerobot_to_original_format(batch, modality_config)
original_obs_transformed = modality_transform(deepcopy(observation))
# Important: Reset seed immediately before inference to ensure identical RNG state
torch.manual_seed(42)
with torch.no_grad():
original_model_output = original_policy.model.get_action(original_obs_transformed)
original_actions_raw = original_model_output["action_pred"] # [2, 16, 32]
# Take first timestep
original_actions = original_actions_raw[:, 0, :].to(lerobot_actions.device).to(lerobot_actions.dtype)
print("Action Comparison:")
diff = lerobot_actions - original_actions
abs_diff = torch.abs(diff)
for batch_idx in range(lerobot_actions.shape[0]):
print(f"\n{'=' * 60}")
print(f"Batch {batch_idx}")
print(f"{'=' * 60}")
print(f"{'Idx':<5} {'LeRobot':<14} {'Original':<14} {'Difference':<14}")
print("-" * 60)
for action_idx in range(lerobot_actions.shape[1]):
lr_val = lerobot_actions[batch_idx, action_idx].item()
orig_val = original_actions[batch_idx, action_idx].item()
diff_val = abs(lr_val - orig_val)
sign = "+" if (lr_val - orig_val) > 0 else "-"
print(f"{action_idx:<5} {lr_val:>13.6f} {orig_val:>13.6f} {sign}{diff_val:>12.6f}")
max_diff = abs_diff.max().item()
tolerance = 0.001
assert torch.allclose(lerobot_actions, original_actions, atol=tolerance), (
f"Actions differ by more than tolerance ({tolerance}): max diff = {max_diff:.6f}"
)
print(f"\nSuccess: Actions match within tolerance ({tolerance})!")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation
cleanup_memory()
def test_groot_forward_pass_comparison():
"""Test forward pass comparison between LeRobot and Original Groot implementations."""
print("Test: Forward Pass Comparison (Training Mode)")
set_seed_all(42)
lerobot_policy, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_groot(
from_pretrained=True
)
original_policy, modality_config, modality_transform = instantiate_original_groot(from_pretrained=True)
batch = create_dummy_data()
lerobot_policy.eval()
original_policy.model.eval()
print("\n[LeRobot] Running forward pass...")
batch_lerobot = deepcopy(batch)
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
set_seed_all(42)
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
print(f" Loss: {lerobot_loss.item():.6f}")
print("\n[Original] Running forward pass...")
observation = convert_lerobot_to_original_format(batch, modality_config)
transformed_obs = modality_transform(observation)
if "action" not in transformed_obs:
action_for_forward = batch_lerobot_processed["action"]
action_mask_for_forward = batch_lerobot_processed["action_mask"]
# Match action horizon if needed
if action_for_forward.shape[1] != original_policy.model.action_horizon:
if action_for_forward.shape[1] < original_policy.model.action_horizon:
pad_size = original_policy.model.action_horizon - action_for_forward.shape[1]
last_action = action_for_forward[:, -1:, :]
padding = last_action.repeat(1, pad_size, 1)
action_for_forward = torch.cat([action_for_forward, padding], dim=1)
mask_padding = torch.zeros(
action_mask_for_forward.shape[0],
pad_size,
action_mask_for_forward.shape[2],
dtype=action_mask_for_forward.dtype,
device=action_mask_for_forward.device,
)
action_mask_for_forward = torch.cat([action_mask_for_forward, mask_padding], dim=1)
else:
action_for_forward = action_for_forward[:, : original_policy.model.action_horizon, :]
action_mask_for_forward = action_mask_for_forward[
:, : original_policy.model.action_horizon, :
]
transformed_obs["action"] = action_for_forward
transformed_obs["action_mask"] = action_mask_for_forward
set_seed_all(42)
with torch.no_grad():
original_outputs = original_policy.model.forward(transformed_obs)
original_loss = original_outputs["loss"]
print(f" Loss: {original_loss.item():.6f}")
loss_diff = abs(lerobot_loss.item() - original_loss.item())
loss_rel_diff = loss_diff / (abs(original_loss.item()) + 1e-8) * 100
print("\nLoss Values:")
print(f" LeRobot: {lerobot_loss.item():.6f}")
print(f" Original: {original_loss.item():.6f}")
print(f" Absolute difference: {loss_diff:.6f}")
print(f" Relative difference: {loss_rel_diff:.2f}%")
del lerobot_policy, lerobot_preprocessor, lerobot_postprocessor
del original_policy, modality_config, modality_transform
del batch, batch_lerobot, observation, transformed_obs
cleanup_memory()