diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index a216548d8..a1fc2660f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -64,7 +64,7 @@ - local: eo1 title: EO-1 - local: groot - title: NVIDIA GR00T N1.5 + title: NVIDIA GR00T - local: xvla title: X-VLA - local: multi_task_dit diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx index a10b5e369..706ab4b78 100644 --- a/docs/source/groot.mdx +++ b/docs/source/groot.mdx @@ -1,16 +1,16 @@ -# GR00T N1.5 Policy +# GR00T Policy -GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments. +GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments. -This document outlines the specifics of its integration and usage within the LeRobot framework. +LeRobot integrates GR00T through the `groot` policy type. The default model family is GR00T N1.5, and GR00T N1.7 can be selected with `policy.model_version=n1.7`. ## Model Overview -NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots. +NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. GR00T N1.7 extends the family with a Cosmos-Reason2/Qwen3-VL backbone and N1.7 checkpoints for SimplerEnv, DROID, and LIBERO. -Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks. +Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks. 
-GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception. +GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception. =2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX -pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies -pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation -python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')" +pip install "lerobot[groot]" ``` -3. Install LeRobot by running: +GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra installs the policy dependencies, including `transformers`, `diffusers`, `peft`, `dm-tree`, and Flash Attention where available. If Flash Attention is unavailable or incompatible, LeRobot falls back to SDPA attention in supported GR00T paths, with lower expected throughput. + +For a source checkout, follow the Environment Setup in the [Installation Guide](./installation), then install the extra: ```bash -pip install lerobot[groot] +uv sync --locked --extra groot ``` +If you need to install Flash Attention manually for your CUDA/PyTorch build, use the wheel or source build recommended by the [Flash Attention project](https://github.com/Dao-AILab/flash-attention). 
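The Flash Attention fallback described above can be sketched as a small availability check. This is a minimal illustration, not LeRobot's actual selection logic: `pick_attention_backend` is a hypothetical helper, and the returned strings are assumed to be passed on as the `attn_implementation` value that `transformers` accepts.

```python
import importlib.util


def pick_attention_backend(prefer_flash: bool = True) -> str:
    """Hypothetical helper: choose an attention backend name.

    Returns "flash_attention_2" when the flash_attn package can be located,
    otherwise "sdpa" (PyTorch scaled dot-product attention).
    """
    if prefer_flash:
        try:
            # find_spec only locates the package; it does not import it.
            found = importlib.util.find_spec("flash_attn") is not None
        except (ImportError, ValueError):
            found = False
        if found:
            return "flash_attention_2"
    return "sdpa"


backend = pick_attention_backend()
print(f"attention backend: {backend}")
```

This check only tells you whether the package is installed. A wheel built against a different CUDA or PyTorch version can still fail at import time, so a real fallback would also need to guard the import itself.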
+ ## Usage -To use GR00T in your LeRobot configuration, specify the policy type as: +To use GR00T N1.5 in your LeRobot configuration, specify the policy type: -```python -policy.type=groot +```bash +--policy.type=groot +``` + +To use GR00T N1.7: + +```bash +--policy.type=groot \ +--policy.model_version=n1.7 ``` ## Training @@ -85,14 +87,20 @@ accelerate launch \ --job_name=$JOB_NAME ``` +For N1.7, add: + +```bash +--policy.model_version=n1.7 +``` + ## Performance Results -### Libero Benchmark Results +### LIBERO Benchmark Results > [!NOTE] -> Follow our instructions for Libero usage: [Libero](./libero) +> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`. -GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results. +GR00T has demonstrated strong performance on the LIBERO benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the LIBERO dataset and compared the results to the GR00T reference results. | Benchmark | LeRobot Implementation | GR00T Reference | | ------------------ | ---------------------- | --------------- | @@ -101,7 +109,49 @@ GR00T has demonstrated strong performance on the Libero benchmark suite. To comp | **Libero Long** | 82.0% | 76.0% | | **Average** | 87.0% | 87.0% | -These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section. +These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, follow the instructions in the [LIBERO](./libero) section. 
+ +### GR00T N1.7 LIBERO Checkpoints + +NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite: + +| Suite | Checkpoint subdirectory | +| -------------- | ----------------------- | +| LIBERO Spatial | `libero_spatial` | +| LIBERO Object | `libero_object` | +| LIBERO Goal | `libero_goal` | +| LIBERO 10 | `libero_10` | + +Preliminary LeRobot integration results: + +| Suite | Status | Success rate | n_episodes | +| -------------- | ------ | -----------: | ---------: | +| LIBERO Spatial | ✓ | ~95% | XX | +| LIBERO Object | ✓ | XX% | XX | +| LIBERO Goal | ✓ | XX% | XX | +| LIBERO 10 | ✓ | XX% | XX | +| **Average** | ✓ | **XX%** | **XX** | + +Replace the `XX` placeholders with final eval artifacts before merge. + +Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field. + +```bash +huggingface-cli download nvidia/GR00T-N1.7-LIBERO \ + --include "libero_spatial/*" \ + --local-dir ./GR00T-N1.7-LIBERO + +lerobot-eval \ + --policy.type=groot \ + --policy.model_version=n1.7 \ + --policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \ + --policy.embodiment_tag=libero_sim \ + --env.type=libero \ + --env.task=libero_spatial \ + --eval.n_episodes=50 +``` + +Use `eval.n_episodes >= 50` per suite when reporting success rates. ### Evaluate in your hardware setup @@ -131,4 +181,4 @@ lerobot-rollout\ ## License -This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**. +GR00T N1.5 follows NVIDIA's license terms, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). 
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/). diff --git a/docs/source/policy_groot_README.md b/docs/source/policy_groot_README.md index efcd76ebe..187301048 100644 --- a/docs/source/policy_groot_README.md +++ b/docs/source/policy_groot_README.md @@ -24,4 +24,8 @@ Code: https://github.com/NVIDIA/Isaac-GR00T Blog: https://developer.nvidia.com/isaac/gr00t -Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B +Hugging Face Models: + +- GR00T N1.5: https://huggingface.co/nvidia/GR00T-N1.5-3B +- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B +- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 05fda05d8..50f3e0f27 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -18,6 +18,7 @@ from __future__ import annotations import importlib import logging +from copy import copy from typing import TYPE_CHECKING, Any, TypedDict, Unpack import torch @@ -48,7 +49,7 @@ from .act.configuration_act import ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig from .eo1.configuration_eo1 import EO1Config from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig -from .groot.configuration_groot import GrootConfig +from .groot.configuration_groot import GROOT_N1_7, GrootConfig from .molmoact2.configuration_molmoact2 import MolmoAct2Config from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config @@ -273,24 +274,47 @@ def make_pre_post_processors( policy configuration type. 
""" if pretrained_path: + if isinstance(policy_cfg, GrootConfig): + from .groot.configuration_groot import is_raw_groot_n1_7_checkpoint + + if is_raw_groot_n1_7_checkpoint(pretrained_path): + from .groot.processor_groot import make_groot_pre_post_processors + + processor_cfg = copy(policy_cfg) + processor_cfg.base_model_path = str(pretrained_path) + return make_groot_pre_post_processors( + config=processor_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + # TODO(Steven): Temporary patch, implement correctly the processors for Gr00t if isinstance(policy_cfg, GrootConfig): - # GROOT handles normalization in groot_pack_inputs_v3 step + # GROOT handles normalization in its pack-inputs step # Need to override both stats AND normalize_min_max since saved config might be empty - preprocessor_overrides = {} - postprocessor_overrides = {} - preprocessor_overrides["groot_pack_inputs_v3"] = { - "stats": kwargs.get("dataset_stats"), - "normalize_min_max": True, - } + dataset_stats = kwargs.get("dataset_stats") + preprocessor_overrides = dict(kwargs.get("preprocessor_overrides", {})) + postprocessor_overrides = dict(kwargs.get("postprocessor_overrides", {})) + pack_inputs_key = ( + "groot_n1_7_pack_inputs_v1" + if policy_cfg.model_version == GROOT_N1_7 + else "groot_pack_inputs_v3" + ) + pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {})) + pack_input_overrides["normalize_min_max"] = True + if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7: + pack_input_overrides["stats"] = dataset_stats + preprocessor_overrides[pack_inputs_key] = pack_input_overrides # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats env_action_dim = policy_cfg.output_features[ACTION].shape[0] - postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = { - "stats": kwargs.get("dataset_stats"), - "normalize_min_max": True, - "env_action_dim": env_action_dim, - } + action_unpack_overrides = dict( + 
postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}) + ) + action_unpack_overrides["normalize_min_max"] = True + action_unpack_overrides["env_action_dim"] = env_action_dim + if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7: + action_unpack_overrides["stats"] = dataset_stats + postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides kwargs["preprocessor_overrides"] = preprocessor_overrides kwargs["postprocessor_overrides"] = postprocessor_overrides diff --git a/src/lerobot/policies/groot/__init__.py b/src/lerobot/policies/groot/__init__.py index c8933ff56..dd67cc5fb 100644 --- a/src/lerobot/policies/groot/__init__.py +++ b/src/lerobot/policies/groot/__init__.py @@ -18,4 +18,12 @@ from .configuration_groot import GrootConfig from .modeling_groot import GrootPolicy from .processor_groot import make_groot_pre_post_processors -__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"] +__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"] + + +def __getattr__(name: str): + if name in {"GR00TN17", "GR00TN17Config"}: + from .groot_n1_7 import GR00TN17, GR00TN17Config + + return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py index a4cd1a0b7..0991ef029 100755 --- a/src/lerobot/policies/groot/action_head/cross_attention_dit.py +++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py @@ -181,8 +181,7 @@ class BasicTransformerBlock(nn.Module): attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - # encoder_attention_mask=encoder_attention_mask, + attention_mask=encoder_attention_mask if encoder_hidden_states is not None else 
attention_mask, ) if self.final_dropout: attn_output = self.final_dropout(attn_output) @@ -318,6 +317,71 @@ class DiT(ModelMixin, ConfigMixin): return self.proj_out_2(hidden_states) +class AlternateVLDiT(DiT): + """N1.7 DiT variant that alternates cross-attention over image and text tokens.""" + + def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs): + super().__init__(*args, **kwargs) + self.attend_text_every_n_blocks = attend_text_every_n_blocks + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_all_hidden_states: bool = False, + image_mask: torch.Tensor | None = None, + backbone_attention_mask: torch.Tensor | None = None, + ): + if image_mask is None: + raise ValueError("image_mask is required for AlternateVLDiT.") + if backbone_attention_mask is None: + raise ValueError("backbone_attention_mask is required for AlternateVLDiT.") + + temb = self.timestep_encoder(timestep) + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() + + image_attention_mask = image_mask & backbone_attention_mask + non_image_attention_mask = (~image_mask) & backbone_attention_mask + + all_hidden_states = [hidden_states] + if not self.config.interleave_self_attention: + raise ValueError("AlternateVLDiT requires interleave_self_attention=True.") + + for idx, block in enumerate(self.transformer_blocks): + if idx % 2 == 1: + hidden_states = block( + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + temb=temb, + ) + else: + curr_encoder_attention_mask = ( + non_image_attention_mask + if idx % (2 * self.attend_text_every_n_blocks) == 0 + else image_attention_mask + ) + hidden_states = block( + hidden_states, + attention_mask=None, + encoder_hidden_states=encoder_hidden_states, + 
encoder_attention_mask=curr_encoder_attention_mask, + temb=temb, + ) + all_hidden_states.append(hidden_states) + + conditioning = temb + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + if return_all_hidden_states: + return self.proj_out_2(hidden_states), all_hidden_states + return self.proj_out_2(hidden_states) + + class SelfAttentionTransformer(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index 17cb631d7..78b8edaf6 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -14,12 +14,295 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import os from dataclasses import dataclass, field +from pathlib import Path from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE +GROOT_N1_5 = "n1.5" +GROOT_N1_7 = "n1.7" +GROOT_N1_5_BASE_MODEL = "nvidia/GR00T-N1.5-3B" +GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" +GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" +GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" + +_GROOT_MODEL_VERSION_ALIASES = { + "n1.5": GROOT_N1_5, + "n1_5": GROOT_N1_5, + "n15": GROOT_N1_5, + "1.5": GROOT_N1_5, + "n1.7": GROOT_N1_7, + "n1_7": GROOT_N1_7, + "n1d7": GROOT_N1_7, + "n17": GROOT_N1_7, + "1.7": GROOT_N1_7, +} + +_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = { + "none": None, + "": None, + GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO, +} + + +def normalize_groot_model_version(model_version: str) -> str: + normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower()) + if 
normalized is None: + supported = ", ".join(sorted({GROOT_N1_5, GROOT_N1_7})) + raise ValueError( + f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}." + ) + return normalized + + +def normalize_groot_action_decode_transform(transform: str | None) -> str | None: + if transform is None: + return None + normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower()) + if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES: + supported = ", ".join( + sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None) + ) + raise ValueError( + f"Unsupported GR00T N1.7 action decode transform '{transform}'. " + f"Supported transforms: none, {supported}." + ) + return normalized + + +def infer_groot_model_version(model_path: str | None) -> str | None: + if not model_path: + return None + model_path_lower = model_path.lower() + if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower: + return GROOT_N1_7 + if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower: + return GROOT_N1_5 + config_version = _infer_groot_model_version_from_local_config(model_path) + if config_version is not None: + return config_version + return None + + +def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool: + if model_path is None: + return False + + path = Path(model_path).expanduser() + if path.is_dir(): + config_path = path / "config.json" + elif path.name == "config.json": + config_path = path + else: + return False + + try: + with config_path.open() as f: + config = json.load(f) + except (OSError, json.JSONDecodeError): + return False + + return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7 + + +def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None: + if model_path is None: + return None + + processor_config_path = Path(model_path).expanduser() / 
"processor_config.json" + try: + with processor_config_path.open() as f: + processor_config = json.load(f) + except (OSError, json.JSONDecodeError): + return None + + modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {}) + if not isinstance(modality_configs, dict): + return None + if "libero_sim" in modality_configs: + return "libero_sim" + if len(modality_configs) == 1: + return next(iter(modality_configs)) + return None + + +def infer_groot_n1_7_action_horizon( + model_path: str | Path | None, embodiment_tag: str | None = None +) -> int | None: + if model_path is None: + return None + + processor_config_path = Path(model_path).expanduser() / "processor_config.json" + try: + with processor_config_path.open() as f: + processor_config = json.load(f) + except (OSError, json.JSONDecodeError): + return None + + processor_kwargs = processor_config.get("processor_kwargs", {}) + if not isinstance(processor_kwargs, dict): + return None + modality_configs = processor_kwargs.get("modality_configs", {}) + if not isinstance(modality_configs, dict): + return None + + if embodiment_tag is None: + embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path) + if embodiment_tag is None: + return None + + embodiment_config = modality_configs.get(embodiment_tag, {}) + if not isinstance(embodiment_config, dict): + return None + action_config = embodiment_config.get("action", {}) + if not isinstance(action_config, dict): + return None + delta_indices = action_config.get("delta_indices", []) + if not isinstance(delta_indices, list): + return None + return len(delta_indices) or None + + +def infer_groot_n1_7_action_execution_horizon( + model_path: str | Path | None, embodiment_tag: str | None = None +) -> int | None: + action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag) + if action_horizon is None: + return None + + if embodiment_tag is None: + embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path) + if embodiment_tag 
== "libero_sim": + # NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded + # actions. Keeping that execution cadence avoids stale open-loop chunks. + return min(action_horizon, 8) + return action_horizon + + +def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str: + model_path = Path(model_name).expanduser() + if model_path.exists(): + return str(model_path) + + cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir) + return str(cached_snapshot) if cached_snapshot is not None else model_name + + +def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None: + repo_cache_name = f"models--{repo_id.replace('/', '--')}" + required_files = ( + "config.json", + "tokenizer_config.json", + "preprocessor_config.json", + "video_preprocessor_config.json", + ) + + for hub_cache in _candidate_hf_hub_caches(cache_dir): + repo_cache = hub_cache / repo_cache_name + snapshots_dir = repo_cache / "snapshots" + if not snapshots_dir.is_dir(): + continue + + candidates: list[Path] = [] + ref_path = repo_cache / "refs" / "main" + try: + ref = ref_path.read_text().strip() + except OSError: + ref = "" + if ref: + candidates.append(snapshots_dir / ref) + candidates.extend( + sorted( + (path for path in snapshots_dir.iterdir() if path.is_dir()), + key=lambda path: path.stat().st_mtime, + reverse=True, + ) + ) + + seen: set[Path] = set() + for snapshot in candidates: + if snapshot in seen: + continue + seen.add(snapshot) + if all((snapshot / filename).exists() for filename in required_files): + return snapshot + return None + + +def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]: + candidates: list[Path] = [] + if cache_dir is not None: + cache_path = Path(cache_dir).expanduser() + candidates.append(cache_path) + candidates.append(cache_path / "hub") + + hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE") + if hub_cache: + 
candidates.append(Path(hub_cache).expanduser()) + + hf_home = os.environ.get("HF_HOME") + if hf_home: + candidates.append(Path(hf_home).expanduser() / "hub") + + candidates.append(Path.home() / ".cache" / "huggingface" / "hub") + + deduped: list[Path] = [] + seen: set[Path] = set() + for candidate in candidates: + resolved = candidate.resolve() if candidate.exists() else candidate + if resolved not in seen: + seen.add(resolved) + deduped.append(candidate) + return deduped + + +def _infer_groot_model_version_from_local_config(model_path: str) -> str | None: + path = Path(model_path).expanduser() + if path.is_dir(): + config_path = path / "config.json" + elif path.name == "config.json": + config_path = path + else: + return None + + if not config_path.exists(): + return None + + try: + with config_path.open() as f: + config = json.load(f) + except (OSError, json.JSONDecodeError): + return None + + return _infer_groot_model_version_from_config(config) + + +def _infer_groot_model_version_from_config(config: dict) -> str | None: + model_version = config.get("model_version") + if isinstance(model_version, str): + try: + return normalize_groot_model_version(model_version) + except ValueError: + return None + + candidates = [config.get("model_type"), *(config.get("architectures") or [])] + for candidate in candidates: + if not isinstance(candidate, str): + continue + normalized = candidate.lower().replace("-", "_") + if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}: + return GROOT_N1_7 + if normalized in {"gr00t_n1_5", "gr00tn15", "gr00t_n1d5"}: + return GROOT_N1_5 + + if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL: + return GROOT_N1_7 + return None + @PreTrainedConfig.register_subclass("groot") @dataclass @@ -52,12 +335,21 @@ class GrootConfig(PreTrainedConfig): # Groot-specific model parameters (from groot_finetune_script.py) + # Explicit GR00T model family selection. Defaults to N1.5 to preserve existing behavior. 
+ model_version: str = GROOT_N1_5 + # Path or HuggingFace model ID for the base Groot model - base_model_path: str = "nvidia/GR00T-N1.5-3B" + base_model_path: str | None = None # HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer. tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5" + # HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor. + n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL + + # Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step(). + action_decode_transform: str | None = None + # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1') embodiment_tag: str = "new_embodiment" @@ -117,6 +409,35 @@ class GrootConfig(PreTrainedConfig): resume: bool = False def __post_init__(self): + self.model_version = normalize_groot_model_version(self.model_version) + self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform) + if self.base_model_path is None: + self.base_model_path = ( + GROOT_N1_7_BASE_MODEL if self.model_version == GROOT_N1_7 else GROOT_N1_5_BASE_MODEL + ) + + if self.action_decode_transform is not None and self.model_version != GROOT_N1_7: + raise ValueError("action_decode_transform can only be used with model_version='n1.7'.") + + if self.model_version == GROOT_N1_7: + if self.max_state_dim == 64: + self.max_state_dim = 132 + if self.max_action_dim == 32: + self.max_action_dim = 132 + if self.chunk_size == 50: + self.chunk_size = 40 + if self.n_action_steps == 50: + self.n_action_steps = 40 + if tuple(self.image_size) == (224, 224): + self.image_size = (256, 256) + + inferred_version = infer_groot_model_version(self.base_model_path) + if inferred_version is not None and inferred_version != self.model_version: + raise ValueError( + f"GR00T model_version '{self.model_version}' does not match base_model_path " + f"'{self.base_model_path}', which looks like 
'{inferred_version}'." + ) + super().__post_init__() if self.n_action_steps > self.chunk_size: @@ -192,7 +513,12 @@ class GrootConfig(PreTrainedConfig): @property def action_delta_indices(self) -> list[int]: """Return indices for delta actions.""" - return list(range(min(self.chunk_size, 16))) + model_action_horizon = 16 + if self.model_version == GROOT_N1_7: + model_action_horizon = ( + infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40 + ) + return list(range(min(self.chunk_size, model_action_horizon))) @property def reward_delta_indices(self) -> None: diff --git a/src/lerobot/policies/groot/groot_n1_7.py b/src/lerobot/policies/groot/groot_n1_7.py new file mode 100644 index 000000000..103c03a58 --- /dev/null +++ b/src/lerobot/policies/groot/groot_n1_7.py @@ -0,0 +1,962 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import importlib +import json +import logging +from contextlib import suppress +from copy import deepcopy +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn.functional as F # noqa: N812 +from huggingface_hub import snapshot_download +from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError +from torch import nn +from torch.distributions import Beta + +from lerobot.utils.import_utils import _transformers_available, require_package + +from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel + from transformers.feature_extraction_utils import BatchFeature +else: + AutoConfig = None + AutoModel = None + PretrainedConfig = object + PreTrainedModel = object + BatchFeature = None + +try: + import tree +except ImportError: + tree = None + +try: + from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration +except ImportError: + Qwen3VLConfig = None + Qwen3VLForConditionalGeneration = None + +logger = logging.getLogger(__name__) + + +def _copy_default(value: Any) -> Any: + return deepcopy(value) + + +GR00T_N1_7_DEFAULTS: dict[str, Any] = { + "model_dtype": "bfloat16", + "dtype": "bfloat16", + "model_name": "nvidia/Cosmos-Reason2-2B", + "backbone_model_type": "qwen", + "model_revision": None, + "tune_top_llm_layers": 0, + "backbone_embedding_dim": 2048, + "tune_llm": False, + "tune_visual": False, + "select_layer": 12, + "reproject_vision": False, + "use_flash_attention": True, + "load_bf16": False, + "backbone_trainable_params_fp32": True, + "image_crop_size": (230, 230), + "image_target_size": (256, 256), + "shortest_image_edge": None, + "crop_fraction": None, + "random_rotation_angle": None, + "color_jitter_params": None, + "use_albumentations_transforms": True, + "extra_augmentation_config": None, + 
"formalize_language": True, + "apply_sincos_state_encoding": False, + "use_percentiles": True, + "use_relative_action": False, + "max_state_dim": 132, + "max_action_dim": 132, + "action_horizon": 40, + "hidden_size": 1024, + "input_embedding_dim": 1536, + "state_history_length": 1, + "add_pos_embed": True, + "attn_dropout": 0.2, + "use_vlln": True, + "max_seq_len": 1024, + "use_alternate_vl_dit": True, + "attend_text_every_n_blocks": 2, + "diffusion_model_cfg": { + "positional_embeddings": None, + "num_layers": 32, + "num_attention_heads": 32, + "attention_head_dim": 48, + "norm_type": "ada_norm", + "dropout": 0.2, + "final_dropout": True, + "output_dim": 1024, + "interleave_self_attention": True, + }, + "vl_self_attention_cfg": { + "positional_embeddings": None, + "num_layers": 4, + "num_attention_heads": 32, + "attention_head_dim": 64, + "dropout": 0.2, + "final_dropout": True, + }, + "num_inference_timesteps": 4, + "noise_beta_alpha": 1.5, + "noise_beta_beta": 1.0, + "noise_s": 0.999, + "num_timestep_buckets": 1000, + "tune_projector": True, + "tune_diffusion_model": True, + "tune_vlln": True, + "state_dropout_prob": 0.2, + "exclude_state": False, + "use_mean_std": False, + "max_num_embodiments": 32, + "rtc_ramp_rate": 6.0, +} + + +class GR00TN17Config(PretrainedConfig): + """Configuration for NVIDIA GR00T N1.7. + + N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment + flow-matching action head. This mirrors the public N1.7 checkpoint config + while keeping it local to LeRobot and independent from the external + Isaac-GR00T ``gr00t`` Python package. 
+ """ + + model_type = "Gr00tN1d7" + + _defaults = GR00T_N1_7_DEFAULTS + + def __init__(self, **kwargs): + super().__init__(**kwargs) + for key, value in GR00T_N1_7_DEFAULTS.items(): + setattr(self, key, _copy_default(kwargs.pop(key, value))) + for key, value in kwargs.items(): + setattr(self, key, value) + + def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]: + cfg = self.to_dict() + if not exclude_augment: + return cfg + exclude_keys = { + "random_rotation_angle", + "color_jitter_params", + "use_albumentations_transforms", + "formalize_language", + "image_crop_size", + "image_target_size", + "shortest_image_edge", + "crop_fraction", + } + return {k: v for k, v in cfg.items() if k not in exclude_keys} + + def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str: + return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs) + + +class CategorySpecificLinear(nn.Module): + """Linear layer with category-specific weights for multi-embodiment support.""" + + def __init__(self, num_categories: int, input_dim: int, hidden_dim: int): + super().__init__() + self.num_categories = num_categories + self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim)) + self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim)) + + def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor: + selected_w = self.W[cat_ids] + selected_b = self.b[cat_ids] + return torch.bmm(x, selected_w) + selected_b.unsqueeze(1) + + +class CategorySpecificMLP(nn.Module): + """Two-layer MLP with category-specific weights.""" + + def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int): + super().__init__() + self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim) + self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim) + + def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor: + hidden = 
F.relu(self.layer1(x, cat_ids)) + return self.layer2(hidden, cat_ids) + + +class SinusoidalPositionalEncoding(nn.Module): + """Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``. + + The frequency scalar is intentionally created on CPU and then broadcast with + the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep + embedding and avoids tiny dtype/device construction differences in parity + tests. + """ + + def __init__(self, embedding_dim: int): + super().__init__() + self.embedding_dim = embedding_dim + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + timesteps = timesteps.float() + half_dim = self.embedding_dim // 2 + exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * ( + torch.log(torch.tensor(10000.0)) / half_dim + ) + freqs = timesteps.unsqueeze(-1) * exponent.exp() + return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1) + + +def swish(x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +class MultiEmbodimentActionEncoder(nn.Module): + """Action encoder with category-specific projections and sinusoidal time encoding.""" + + def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int): + super().__init__() + self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) + self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) + self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) + self.pos_encoding = SinusoidalPositionalEncoding(hidden_size) + + def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor: + batch_size, horizon, _ = actions.shape + if timesteps.dim() != 1 or timesteps.shape[0] != batch_size: + raise ValueError("Expected `timesteps` to have shape (B,).") + timesteps = timesteps.unsqueeze(1).expand(-1, horizon) + action_emb = self.W1(actions, cat_ids) + time_emb = 
self.pos_encoding(timesteps).to(dtype=action_emb.dtype) + x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids)) + return self.W3(x, cat_ids) + + +class Qwen3Backbone(nn.Module): + """Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7. + + The public checkpoint stores the action head in the GR00T checkpoint but + uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper + keeps the nested HF module layout compatible across transformer versions + and exposes the hidden states consumed by the action head. + """ + + def __init__( + self, + model_name: str = "nvidia/Cosmos-Reason2-2B", + tune_llm: bool = False, + tune_visual: bool = False, + select_layer: int = -1, + reproject_vision: bool = False, + use_flash_attention: bool = False, + load_bf16: bool = False, + tune_top_llm_layers: int = 0, + trainable_params_fp32: bool = False, + transformers_loading_kwargs: dict[str, Any] | None = None, + load_pretrained_weights: bool = True, + ): + if Qwen3VLForConditionalGeneration is None: + raise ImportError( + "Qwen3VLForConditionalGeneration is required for GR00T N1.7. " + "Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` " + "or use a transformers version that provides Qwen3-VL support." + ) + + super().__init__() + transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True} + + extra_kwargs: dict[str, Any] = {} + if use_flash_attention: + try: + import flash_attn # noqa: F401 + + extra_kwargs["attn_implementation"] = "flash_attention_2" + except ImportError: + logger.warning("flash_attn is not installed. 
Falling back to SDPA attention.") + extra_kwargs["attn_implementation"] = "sdpa" + if load_bf16: + extra_kwargs["torch_dtype"] = torch.bfloat16 + + if load_pretrained_weights: + self.model = Qwen3VLForConditionalGeneration.from_pretrained( + model_name, + **extra_kwargs, + **transformers_loading_kwargs, + ).eval() + else: + self.model = self._from_backbone_config( + model_name=model_name, + model_kwargs=extra_kwargs, + config_kwargs=transformers_loading_kwargs, + ).eval() + + while select_layer >= 0 and len(self.language_model.layers) > select_layer: + self.language_model.layers.pop(-1) + + self.select_layer = select_layer + self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers) + if load_bf16 and trainable_params_fp32: + for parameter in self.parameters(): + if parameter.requires_grad: + parameter.data = parameter.data.to(torch.float32) + + def set_trainable_parameters( + self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0 + ) -> None: + self.tune_llm = tune_llm + self.tune_visual = tune_visual + for parameter in self.parameters(): + parameter.requires_grad = True + if not tune_llm: + self.language_model.requires_grad_(False) + if not tune_visual: + self.visual.requires_grad_(False) + if tune_top_llm_layers > 0: + for layer in self.language_model.layers[-tune_top_llm_layers:]: + for parameter in layer.parameters(): + parameter.requires_grad = True + + def set_frozen_modules_to_eval_mode(self) -> None: + if self.training: + if self.language_model and not self.tune_llm: + self.language_model.eval() + if self.visual and not self.tune_visual: + self.visual.eval() + + @property + def language_model(self) -> nn.Module: + return getattr(self.model, "model", self.model).language_model + + @property + def visual(self) -> nn.Module: + return getattr(self.model, "model", self.model).visual + + def _from_backbone_config( + self, + *, + model_name: str, + model_kwargs: dict[str, Any], + config_kwargs: dict[str, Any], + ) -> nn.Module: + if 
_is_cosmos_reason2_backbone(model_name): + backbone_config = _cosmos_reason2_qwen3_vl_config() + else: + if AutoConfig is None: + raise ImportError( + "AutoConfig is required to initialize a GR00T N1.7 backbone from config. " + "Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`." + ) + backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs) + return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs) + + def prepare_input(self, batch: dict[str, Any]) -> BatchFeature: + return BatchFeature(data=batch) + + def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None: + if "mm_token_type_ids" in model_input: + return + if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input: + return + + input_ids = model_input.get("input_ids") + if input_ids is None: + return + + mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device) + image_token_id = getattr(self.model.config, "image_token_id", None) + video_token_id = getattr(self.model.config, "video_token_id", None) + if image_token_id is not None: + mm_token_type_ids[input_ids == image_token_id] = 1 + if video_token_id is not None: + mm_token_type_ids[input_ids == video_token_id] = 2 + + model_input["mm_token_type_ids"] = mm_token_type_ids + + def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None: + """Restore the Qwen3-VL text position ids used by older Transformers releases. + + Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then + drops text position ids before calling text-layer flash attention. GR00T + N1.7 was aligned against the older Transformers path, where a fourth text + position row is forwarded alongside the temporal/height/width rows. Adding + the row here preserves the newer multimodal position computation while + keeping flash attention on the legacy code path. 
+ """ + + if "position_ids" in model_input: + return + + qwen3_model = getattr(self.model, "model", self.model) + compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None) + if compute_3d_position_ids is None: + return + + position_ids = compute_3d_position_ids( + input_ids=model_input.get("input_ids"), + image_grid_thw=model_input.get("image_grid_thw"), + video_grid_thw=model_input.get("video_grid_thw"), + inputs_embeds=None, + attention_mask=model_input.get("attention_mask"), + past_key_values=None, + mm_token_type_ids=model_input.get("mm_token_type_ids"), + ) + if position_ids.ndim == 3 and position_ids.shape[0] == 3: + position_ids = torch.cat([position_ids[:1], position_ids], dim=0) + + model_input["position_ids"] = position_ids + + def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor: + """Return the pre-final-norm decoder output consumed by the N1.7 action head. + + Older Transformers releases exposed this tensor as ``hidden_states[-1]``. + Newer releases expose the post-final-norm tensor there instead. Capturing + the last decoder layer output directly keeps the N1.7 action head input + stable across Transformers versions. 
+ """ + + captured: dict[str, torch.Tensor] = {} + + def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None: + if isinstance(output, torch.Tensor): + captured["features"] = output + elif isinstance(output, (tuple, list)) and output: + captured["features"] = output[0] + elif hasattr(output, "last_hidden_state"): + captured["features"] = output.last_hidden_state + + hook = self.language_model.layers[-1].register_forward_hook(capture_output) + try: + outputs = self.model(**model_input, output_hidden_states=True) + finally: + hook.remove() + + return captured.get("features", outputs.hidden_states[-1]) + + def forward(self, vl_input: BatchFeature) -> BatchFeature: + self.set_frozen_modules_to_eval_mode() + keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"] + optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"] + model_input = {key: vl_input[key] for key in keys_to_use} + model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input}) + self._ensure_mm_token_type_ids(model_input) + self._ensure_legacy_qwen3_position_ids(model_input) + features = self._last_decoder_layer_output(model_input) + image_mask = model_input["input_ids"] == self.model.config.image_token_id + attention_mask = model_input["attention_mask"] == 1 + return BatchFeature( + data={ + "backbone_features": features, + "backbone_attention_mask": attention_mask, + "image_mask": image_mask, + } + ) + + +class GR00TN17ActionHead(nn.Module): + supports_gradient_checkpointing = True + + def __init__(self, config: GR00TN17Config): + require_package("diffusers", extra="groot") + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.input_embedding_dim = config.input_embedding_dim + + if config.use_alternate_vl_dit: + self.model = AlternateVLDiT( + **config.diffusion_model_cfg, + cross_attention_dim=config.backbone_embedding_dim, + 
attend_text_every_n_blocks=config.attend_text_every_n_blocks, + ) + else: + self.model = DiT( + **config.diffusion_model_cfg, + cross_attention_dim=config.backbone_embedding_dim, + ) + + self.action_dim = config.max_action_dim + self.action_horizon = config.action_horizon + self.num_inference_timesteps = config.num_inference_timesteps + self.state_encoder = CategorySpecificMLP( + num_categories=config.max_num_embodiments, + input_dim=config.max_state_dim * config.state_history_length, + hidden_dim=self.hidden_size, + output_dim=self.input_embedding_dim, + ) + self.action_encoder = MultiEmbodimentActionEncoder( + action_dim=self.action_dim, + hidden_size=self.input_embedding_dim, + num_embodiments=config.max_num_embodiments, + ) + self.action_decoder = CategorySpecificMLP( + num_categories=config.max_num_embodiments, + input_dim=self.hidden_size, + hidden_dim=self.hidden_size, + output_dim=self.action_dim, + ) + self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity() + vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None) + if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0: + self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg) + else: + self.vl_self_attention = nn.Identity() + if config.add_pos_embed: + self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim) + nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) + self.state_dropout_prob = config.state_dropout_prob + self._noise_beta_alpha = config.noise_beta_alpha + self._noise_beta_beta = config.noise_beta_beta + self._beta_dist = None + self.num_timestep_buckets = config.num_timestep_buckets + self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln) + + def set_trainable_parameters( + self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool + ) -> None: + self.tune_projector = tune_projector + 
self.tune_diffusion_model = tune_diffusion_model + self.tune_vlln = tune_vlln + for parameter in self.parameters(): + parameter.requires_grad = True + if not tune_projector: + self.state_encoder.requires_grad_(False) + self.action_encoder.requires_grad_(False) + self.action_decoder.requires_grad_(False) + if self.config.add_pos_embed: + self.position_embedding.requires_grad_(False) + if not tune_diffusion_model: + self.model.requires_grad_(False) + if not tune_vlln: + self.vlln.requires_grad_(False) + self.vl_self_attention.requires_grad_(False) + + def set_frozen_modules_to_eval_mode(self) -> None: + if self.training: + if not self.tune_projector: + self.state_encoder.eval() + self.action_encoder.eval() + self.action_decoder.eval() + if self.config.add_pos_embed: + self.position_embedding.eval() + if not self.tune_diffusion_model: + self.model.eval() + if not self.tune_vlln: + self.vlln.eval() + self.vl_self_attention.eval() + + def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + if self._beta_dist is None: + beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32) + beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32) + self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False) + sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype) + return (1 - sample) * self.config.noise_s + + def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature: + backbone_features = self.vlln(backbone_output["backbone_features"]) + backbone_output["backbone_features"] = self.vl_self_attention(backbone_features) + return backbone_output + + def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: + self.set_frozen_modules_to_eval_mode() + backbone_output = self.process_backbone_output(backbone_output) + vl_embeds = backbone_output.backbone_features + device = vl_embeds.device + embodiment_id = 
action_input.embodiment_id + + if action_input.state.shape[1] != self.config.state_history_length: + raise ValueError("state history length does not match GR00T N1.7 config.") + state = action_input.state.view(action_input.state.shape[0], 1, -1) + state_features = self.state_encoder(state, embodiment_id) + + if self.training and self.state_dropout_prob > 0: + do_dropout = ( + torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob + ) + state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype)) + + actions = action_input.action + noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype) + t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype) + t = t[:, None, None] + noisy_trajectory = (1 - t) * noise + t * actions + velocity = actions - noise + t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long() + action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id) + + if self.config.add_pos_embed: + pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device) + action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0) + + sa_embs = torch.cat((state_features, action_features), dim=1) + if self.config.use_alternate_vl_dit: + model_output, _ = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + encoder_attention_mask=backbone_output.backbone_attention_mask, + timestep=t_discretized, + return_all_hidden_states=True, + image_mask=backbone_output.image_mask, + backbone_attention_mask=backbone_output.backbone_attention_mask, + ) + else: + model_output, _ = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + encoder_attention_mask=backbone_output.backbone_attention_mask, + timestep=t_discretized, + return_all_hidden_states=True, + ) + + pred = self.action_decoder(model_output, embodiment_id) + pred_actions = pred[:, -actions.shape[1] 
:] + action_mask = action_input.action_mask.to(dtype=pred_actions.dtype) + action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask + loss = action_loss.sum() / (action_mask.sum() + 1e-6) + return BatchFeature( + data={ + "loss": loss, + "action_loss": action_loss, + "action_mask": action_mask, + "backbone_features": vl_embeds, + "state_features": state_features, + } + ) + + def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature: + backbone_output = self.process_backbone_output(backbone_output) + state = action_input.state + if state.shape[1] != self.config.state_history_length: + raise ValueError("state history length does not match GR00T N1.7 config.") + state = state.view(state.shape[0], 1, -1) + state_features = self.state_encoder(state, action_input.embodiment_id) + return BatchFeature( + data={"backbone_features": backbone_output.backbone_features, "state_features": state_features} + ) + + @torch.no_grad() + def get_action_with_features( + self, + backbone_features: torch.Tensor, + state_features: torch.Tensor, + embodiment_id: torch.Tensor, + backbone_output: BatchFeature, + action_input: BatchFeature, + options: dict[str, Any] | None = None, + ) -> BatchFeature: + vl_embeds = backbone_features + batch_size = vl_embeds.shape[0] + device = vl_embeds.device + actions = torch.randn( + size=(batch_size, self.config.action_horizon, self.action_dim), + dtype=vl_embeds.dtype, + device=device, + ) + dt = 1.0 / self.num_inference_timesteps + vel_strength = torch.ones_like(actions) + + if "action" in action_input: + if options is None: + raise ValueError("RTC options are required when action is provided to get_action.") + action_horizon_before_padding = options["action_horizon"] + actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][ + :, + action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding, + :, + ] + vel_strength[:, : 
options["rtc_frozen_steps"], :] = 0.0 + intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"] + t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device) + ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t) + ramp = ramp / ramp[-1].clamp_min(1e-8) + vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][ + None, :, None + ].to(device) + + for t_step in range(self.num_inference_timesteps): + t_cont = t_step / float(self.num_inference_timesteps) + t_discretized = int(t_cont * self.num_timestep_buckets) + timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device) + action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id) + if self.config.add_pos_embed: + pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device) + action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0) + sa_embs = torch.cat((state_features, action_features), dim=1) + + if self.config.use_alternate_vl_dit: + model_output = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + timestep=timesteps_tensor, + image_mask=backbone_output.image_mask, + backbone_attention_mask=backbone_output.backbone_attention_mask, + ) + else: + model_output = self.model( + hidden_states=sa_embs, + encoder_hidden_states=vl_embeds, + timestep=timesteps_tensor, + ) + pred = self.action_decoder(model_output, embodiment_id) + actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength + + return BatchFeature( + data={ + "action_pred": actions, + "backbone_features": vl_embeds, + "state_features": state_features, + } + ) + + @torch.no_grad() + def get_action( + self, + backbone_output: BatchFeature, + action_input: BatchFeature, + options: dict[str, Any] | None = None, + ) -> BatchFeature: + features = self._encode_features(backbone_output, action_input) + return self.get_action_with_features( + 
backbone_features=features.backbone_features, + state_features=features.state_features, + embodiment_id=action_input.embodiment_id, + backbone_output=backbone_output, + action_input=action_input, + options=options, + ) + + @property + def device(self) -> torch.device: + return next(iter(self.parameters())).device + + @property + def dtype(self) -> torch.dtype: + return next(iter(self.parameters())).dtype + + def prepare_input(self, batch: dict[str, Any]) -> BatchFeature: + return BatchFeature(data=batch) + + +def _is_cosmos_reason2_backbone(model_name: str) -> bool: + return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B" + + +def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig: + if Qwen3VLConfig is None: + raise ImportError( + "Qwen3VLConfig is required for GR00T N1.7. " + "Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`." + ) + return Qwen3VLConfig( + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + vision_end_token_id=151653, + tie_word_embeddings=True, + text_config={ + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-6, + "rope_scaling": { + "mrope_interleaved": True, + "mrope_section": [24, 20, 20], + "rope_type": "default", + }, + "rope_theta": 5000000, + "tie_word_embeddings": True, + "use_cache": True, + "vocab_size": 151936, + }, + vision_config={ + "deepstack_visual_indexes": [5, 11, 17], + "depth": 24, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1024, + "in_channels": 3, + "initializer_range": 0.02, + "intermediate_size": 4096, + "model_type": "qwen3_vl", + "num_heads": 16, + 
"num_position_embeddings": 2304, + "out_hidden_size": 2048, + "patch_size": 16, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + ) + + +def get_backbone_cls(config: GR00TN17Config): + if ( + config.backbone_model_type == "qwen" + or "nvidia/Cosmos-Reason2" in config.model_name + or "Qwen/Qwen3-VL" in config.model_name + ): + return Qwen3Backbone + raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}") + + +class GR00TN17(PreTrainedModel): + """GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone.""" + + config_class = GR00TN17Config + supports_gradient_checkpointing = True + + def __init__( + self, + config: GR00TN17Config, + transformers_loading_kwargs: dict[str, Any] | None = None, + load_backbone_weights: bool = True, + ): + super().__init__(config) + transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True} + self.config = config + backbone_cls = get_backbone_cls(config) + self.backbone = backbone_cls( + model_name=config.model_name, + tune_llm=config.tune_llm, + tune_visual=config.tune_visual, + select_layer=config.select_layer, + reproject_vision=config.reproject_vision, + use_flash_attention=config.use_flash_attention, + load_bf16=config.load_bf16, + tune_top_llm_layers=config.tune_top_llm_layers, + trainable_params_fp32=config.backbone_trainable_params_fp32, + transformers_loading_kwargs=transformers_loading_kwargs, + load_pretrained_weights=load_backbone_weights, + ) + self.action_head = GR00TN17ActionHead(config) + self.post_init() + + def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]: + global tree + if tree is None: + require_package("dm-tree", extra="groot", import_name="tree") + tree = importlib.import_module("tree") + backbone_inputs = self.backbone.prepare_input(inputs) + action_inputs = self.action_head.prepare_input(inputs) + + def to_device_with_dtype(x): + if not isinstance(x, torch.Tensor): + return x + if torch.is_floating_point(x): + 
return x.to(self.device, dtype=self.dtype) + return x.to(self.device) + + return ( + tree.map_structure(to_device_with_dtype, backbone_inputs), + tree.map_structure(to_device_with_dtype, action_inputs), + ) + + def forward(self, inputs: dict[str, Any]) -> BatchFeature: + backbone_inputs, action_inputs = self.prepare_input(inputs) + backbone_outputs = self.backbone(backbone_inputs) + return self.action_head(backbone_outputs, action_inputs) + + def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature: + backbone_inputs, action_inputs = self.prepare_input(inputs) + backbone_outputs = self.backbone(backbone_inputs) + return self.action_head.get_action(backbone_outputs, action_inputs, options) + + @property + def device(self) -> torch.device: + return next(iter(self.parameters())).device + + @property + def dtype(self) -> torch.dtype: + return next(iter(self.parameters())).dtype + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): + tune_visual = kwargs.pop("tune_visual", True) + tune_llm = kwargs.pop("tune_llm", False) + tune_projector = kwargs.pop("tune_projector", True) + tune_diffusion_model = kwargs.pop("tune_diffusion_model", True) + tune_vlln = kwargs.pop("tune_vlln", True) + transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or { + "trust_remote_code": True + } + load_backbone_weights = kwargs.pop("load_backbone_weights", False) + for key in ("revision", "cache_dir", "local_files_only", "token"): + if key in kwargs: + transformers_loading_kwargs.setdefault(key, kwargs[key]) + + try: + local_model_path = snapshot_download( + pretrained_model_name_or_path, + repo_type="model", + revision=kwargs.get("revision"), + cache_dir=kwargs.get("cache_dir"), + local_files_only=kwargs.get("local_files_only", False), + token=kwargs.get("token"), + ) + except (HFValidationError, RepositoryNotFoundError): + local_model_path = pretrained_model_name_or_path + + 
pretrained_model = super().from_pretrained( + local_model_path, + transformers_loading_kwargs=transformers_loading_kwargs, + load_backbone_weights=load_backbone_weights, + **kwargs, + ) + pretrained_model.backbone.set_trainable_parameters( + tune_visual=tune_visual, + tune_llm=tune_llm, + tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers, + ) + pretrained_model.action_head.set_trainable_parameters( + tune_projector=tune_projector, + tune_diffusion_model=tune_diffusion_model, + tune_vlln=tune_vlln, + ) + return pretrained_model + + +def _register_with_transformers() -> None: + if AutoConfig is None or AutoModel is None: + return + try: + AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True) + except TypeError: + with suppress(ValueError): + AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config) + try: + AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True) + except TypeError: + with suppress(ValueError): + AutoModel.register(GR00TN17Config, GR00TN17) + + +_register_with_transformers() diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 2e2e9ca89..a28d0c148 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -46,7 +46,15 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES from lerobot.utils.import_utils import require_package from ..pretrained import PreTrainedPolicy -from .configuration_groot import GrootConfig +from .configuration_groot import ( + GROOT_N1_5, + GROOT_N1_7, + GrootConfig, + infer_groot_model_version, + infer_groot_n1_7_action_execution_horizon, + infer_groot_n1_7_action_horizon, + normalize_groot_model_version, +) from .groot_n1 import GR00TN15 T = TypeVar("T", bound="GrootPolicy") @@ -67,6 +75,7 @@ class GrootPolicy(PreTrainedPolicy): # Initialize GR00T model using ported components self._groot_model = self._create_groot_model() + self._action_queue_steps = 
self._resolve_action_queue_steps() self.reset() @@ -82,13 +91,23 @@ class GrootPolicy(PreTrainedPolicy): # Handle Flash Attention compatibility issues self._handle_flash_attention_compatibility() - model = GR00TN15.from_pretrained( - pretrained_model_name_or_path=self.config.base_model_path, - tune_llm=self.config.tune_llm, - tune_visual=self.config.tune_visual, - tune_projector=self.config.tune_projector, - tune_diffusion_model=self.config.tune_diffusion_model, - ) + model_kwargs = { + "pretrained_model_name_or_path": self.config.base_model_path, + "tune_llm": self.config.tune_llm, + "tune_visual": self.config.tune_visual, + "tune_projector": self.config.tune_projector, + "tune_diffusion_model": self.config.tune_diffusion_model, + } + if self.config.model_version == GROOT_N1_7: + from .groot_n1_7 import GR00TN17 + + model = GR00TN17.from_pretrained( + **model_kwargs, + tune_vlln=True, + transformers_loading_kwargs={"trust_remote_code": True}, + ) + else: + model = GR00TN15.from_pretrained(**model_kwargs) model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype model.config.compute_dtype = model.compute_dtype @@ -97,7 +116,7 @@ class GrootPolicy(PreTrainedPolicy): def reset(self): """Reset policy state when environment resets.""" - self._action_queue = deque([], maxlen=self.config.n_action_steps) + self._action_queue = deque([], maxlen=self._action_queue_steps) @classmethod def from_pretrained( @@ -141,8 +160,13 @@ class GrootPolicy(PreTrainedPolicy): from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.errors import HfHubHTTPError + requested_version = ( + normalize_groot_model_version(config.model_version) + if config is not None + else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_5 + ) print( - "The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n" + f"The Groot policy is a wrapper around NVIDIA's GR00T {requested_version} model.\n" f"Loading pretrained model from: 
{pretrained_name_or_path}" ) @@ -193,8 +217,12 @@ class GrootPolicy(PreTrainedPolicy): print("Detected base GR00T model, loading from HuggingFace...") if config is None: + model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_5 # Create default config with the pretrained path - config = GrootConfig(base_model_path=str(pretrained_name_or_path)) + config = GrootConfig( + model_version=model_version, + base_model_path=str(pretrained_name_or_path), + ) # Add minimal visual feature required for validation # validate_features() will automatically add state and action features @@ -215,6 +243,25 @@ class GrootPolicy(PreTrainedPolicy): if hasattr(config, key): setattr(config, key, value) + config.model_version = normalize_groot_model_version(config.model_version) + inferred_version = infer_groot_model_version(config.base_model_path) + if inferred_version is not None and inferred_version != config.model_version: + raise ValueError( + f"GR00T model_version '{config.model_version}' does not match base_model_path " + f"'{config.base_model_path}', which looks like '{inferred_version}'." 
+ ) + if config.model_version == GROOT_N1_7: + if config.max_state_dim == 64: + config.max_state_dim = 132 + if config.max_action_dim == 32: + config.max_action_dim = 132 + if config.chunk_size == 50: + config.chunk_size = 40 + if config.n_action_steps == 50: + config.n_action_steps = 40 + if tuple(config.image_size) == (224, 224): + config.image_size = (256, 256) + # Create a fresh policy instance - this will automatically load the GR00T model # in __init__ via _create_groot_model() policy = cls(config) @@ -225,18 +272,59 @@ class GrootPolicy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + def _resolve_action_queue_steps(self) -> int: + n_action_steps = int(self.config.n_action_steps) + if self.config.model_version != GROOT_N1_7: + return n_action_steps + + checkpoint_action_horizon = infer_groot_n1_7_action_horizon( + self.config.base_model_path, + self.config.embodiment_tag, + ) + execution_horizon = infer_groot_n1_7_action_execution_horizon( + self.config.base_model_path, + self.config.embodiment_tag, + ) + horizons = [n_action_steps] + if checkpoint_action_horizon is not None: + horizons.append(checkpoint_action_horizon) + if execution_horizon is not None: + horizons.append(execution_horizon) + return min(horizons) + + def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]: + allowed_base = {"state", "state_mask", "embodiment_id"} + if include_action: + allowed_base.update({"action", "action_mask"}) + + if self.config.model_version == GROOT_N1_7: + allowed_base.update( + { + "input_ids", + "attention_mask", + "pixel_values", + "image_grid_thw", + "mm_token_type_ids", + "pixel_values_videos", + "video_grid_thw", + } + ) + allowed_base.add("action_mask") + else: + allowed_base.update({"action_mask"} if include_action else set()) + + return { + k: v + for k, v in batch.items() + if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info") + } + def 
forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Training forward pass.

         Delegates to Isaac-GR00T model.forward when inputs are compatible.
         """
-        # Build a clean input dict for GR00T: keep only tensors GR00T consumes
-        allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
-        groot_inputs = {
-            k: v
-            for k, v in batch.items()
-            if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
-        }
+        groot_inputs = self._filter_groot_inputs(batch, include_action=True)

         # Get device from model parameters
         device = next(self.parameters()).device
@@ -261,15 +349,10 @@ class GrootPolicy(PreTrainedPolicy):
         """
         self.eval()

-        # Build a clean input dict for GR00T: keep only tensors GR00T consumes
-        # Preprocessing is handled by the processor pipeline, so we just filter the batch
-        # NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
-        allowed_base = {"state", "state_mask", "embodiment_id"}
-        groot_inputs = {
-            k: v
-            for k, v in batch.items()
-            if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
-        }
+        # Preprocessing is handled by the processor pipeline, so we just filter the batch.
+        # During inference, we do not pass action because it is predicted.
+        # N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
+ groot_inputs = self._filter_groot_inputs(batch, include_action=False) # Get device from model parameters device = next(self.parameters()).device @@ -292,7 +375,7 @@ class GrootPolicy(PreTrainedPolicy): if len(self._action_queue) == 0: actions = self.predict_action_chunk(batch) - self._action_queue.extend(actions.transpose(0, 1)) + self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1)) return self._action_queue.popleft() # ------------------------- diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 6848c7c84..94f895c51 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING, Any import numpy as np @@ -57,11 +59,403 @@ from lerobot.utils.constants import ( POLICY_PREPROCESSOR_DEFAULT_NAME, ) -from .configuration_groot import GrootConfig +from .configuration_groot import ( + GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + GROOT_N1_7, + GROOT_N1_7_BACKBONE_MODEL, + GrootConfig, + is_raw_groot_n1_7_checkpoint, +) # Defaults for Eagle processor locations DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5" +N1_7_EMBODIMENT_MAPPING = { + "oxe_droid_relative_eef_relative_joint": 24, + "xdof_relative_eef_relative_joint": 27, + "xdof_relative_eef_relative_joint_subtask": 27, + "real_g1_relative_eef_relative_joints": 25, + "real_r1_pro_sharpa_relative_eef": 26, + "real_r1_pro_sharpa_relative_eef_human": 26, + "real_r1_pro_sharpa_relative_eef_maxinsights": 26, + "real_r1_pro_sharpa_relative_eef_mecka": 26, + "unitree_g1_full_body_with_waist_height_nav_cmd": 25, + "simpler_env_google": 0, + "simpler_env_widowx": 1, + "libero_sim": 2, + "new_embodiment": 10, +} + +_N1_7_RAW_STATE_CACHE: 
dict[str, dict[str, np.ndarray]] = {} + + +def _n1_7_state_cache_key(value: str | None) -> str: + return value or "groot_n1_7_default" + + +@dataclass +class _GrootN17CheckpointProcessorAssets: + """Processor metadata loaded from a raw Isaac-GR00T N1.7 checkpoint. + + Public N1.7 checkpoints store preprocessing and action-decoding choices next + to the model weights. Keeping those values together avoids falling back to + LeRobot defaults that are valid for older GR00T variants but change N1.7 + inputs or decoded actions. + """ + + stats: dict[str, dict[str, Any]] + raw_stats: dict[str, Any] + modality_config: dict[str, Any] + embodiment_mapping: dict[str, int] + formalize_language: bool + valid_action_horizon: int | None + max_action_horizon: int | None + video_horizon: int | None + use_percentiles: bool + use_relative_action: bool + clip_outliers: bool + video_modality_keys: list[str] | None + image_crop_size: list[int] | None + image_target_size: list[int] | None + shortest_image_edge: int | None + crop_fraction: float | None + use_albumentations: bool + + +def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None: + """Load N1.7 processor settings from checkpoint sidecar JSON files. + + Returns ``None`` for non-raw N1.7 checkpoints so the generic GR00T pipeline + can keep using caller-provided dataset stats and config values. 
+ """ + + if not is_raw_groot_n1_7_checkpoint(config.base_model_path): + return None + + checkpoint_path = Path(config.base_model_path).expanduser() + processor_config = _read_json(checkpoint_path / "processor_config.json") + processor_kwargs = processor_config.get("processor_kwargs", {}) + if not isinstance(processor_kwargs, dict): + processor_kwargs = {} + + all_stats = _read_json(checkpoint_path / "statistics.json") + raw_stats = all_stats.get(config.embodiment_tag) + if not isinstance(raw_stats, dict): + raw_stats = {} + + modality_configs = processor_kwargs.get("modality_configs", {}) + if not isinstance(modality_configs, dict): + modality_configs = {} + modality_config = modality_configs.get(config.embodiment_tag) + if not isinstance(modality_config, dict): + modality_config = {} + + use_relative_action = bool(processor_kwargs.get("use_relative_action", False)) + stats = _load_n1_7_checkpoint_stats( + checkpoint_path, + processor_kwargs, + config.embodiment_tag, + raw_stats=raw_stats, + modality_config=modality_config, + use_relative_action=use_relative_action, + ) + embodiment_mapping = _load_n1_7_embodiment_mapping(checkpoint_path) or dict(N1_7_EMBODIMENT_MAPPING) + formalize_language = processor_kwargs.get("formalize_language", True) + if not isinstance(formalize_language, bool): + formalize_language = True + clip_outliers = processor_kwargs.get("clip_outliers", True) + if not isinstance(clip_outliers, bool): + clip_outliers = True + use_albumentations = processor_kwargs.get("use_albumentations", False) + if not isinstance(use_albumentations, bool): + use_albumentations = False + + valid_action_horizon = _load_n1_7_checkpoint_action_horizon(processor_kwargs, config.embodiment_tag) + video_horizon = _load_n1_7_checkpoint_video_horizon(processor_kwargs, config.embodiment_tag) + video_modality_keys = _load_n1_7_checkpoint_video_modality_keys(processor_kwargs, config.embodiment_tag) + max_action_horizon = processor_kwargs.get("max_action_horizon") + if not 
isinstance(max_action_horizon, int): + max_action_horizon = None + + return _GrootN17CheckpointProcessorAssets( + stats=stats, + raw_stats=raw_stats, + modality_config=modality_config, + embodiment_mapping=embodiment_mapping, + formalize_language=formalize_language, + valid_action_horizon=valid_action_horizon, + max_action_horizon=max_action_horizon, + video_horizon=video_horizon, + use_percentiles=bool(processor_kwargs.get("use_percentiles", False)), + use_relative_action=use_relative_action, + clip_outliers=clip_outliers, + video_modality_keys=video_modality_keys, + image_crop_size=_as_int_pair(processor_kwargs.get("image_crop_size")), + image_target_size=_as_int_pair(processor_kwargs.get("image_target_size")), + shortest_image_edge=_as_optional_int(processor_kwargs.get("shortest_image_edge")), + crop_fraction=_as_optional_float(processor_kwargs.get("crop_fraction")), + use_albumentations=use_albumentations, + ) + + +def _read_json(path: Path) -> dict[str, Any]: + try: + with path.open() as f: + data = json.load(f) + except (OSError, json.JSONDecodeError): + return {} + return data if isinstance(data, dict) else {} + + +def _load_n1_7_embodiment_mapping(checkpoint_path: Path) -> dict[str, int] | None: + mapping = _read_json(checkpoint_path / "embodiment_id.json") + if not mapping: + return None + parsed: dict[str, int] = {} + for key, value in mapping.items(): + if not isinstance(key, str): + continue + try: + parsed[key] = int(value) + except (TypeError, ValueError): + continue + return parsed or None + + +def _load_n1_7_checkpoint_stats( + checkpoint_path: Path, + processor_kwargs: dict[str, Any], + embodiment_tag: str, + *, + raw_stats: dict[str, Any] | None = None, + modality_config: dict[str, Any] | None = None, + use_relative_action: bool = False, +) -> dict[str, dict[str, Any]]: + """Convert checkpoint modality-group stats into LeRobot flat tensor stats. + + Isaac-GR00T keeps statistics keyed by semantic groups such as EEF pose and + joints. 
LeRobot normalizers operate over a single vector, so this function + preserves checkpoint group order while flattening each selected statistic. + """ + + if raw_stats is None: + all_stats = _read_json(checkpoint_path / "statistics.json") + raw_stats = all_stats.get(embodiment_tag) + if not isinstance(raw_stats, dict): + return {} + + if modality_config is None: + modality_configs = processor_kwargs.get("modality_configs", {}) + if not isinstance(modality_configs, dict): + return {} + modality_config = modality_configs.get(embodiment_tag) + if not isinstance(modality_config, dict): + return {} + + use_percentiles = processor_kwargs.get("use_percentiles", False) + return { + OBS_STATE: _flatten_n1_7_modality_stats( + embodiment_stats=raw_stats, + embodiment_config=modality_config, + modality="state", + use_percentiles=bool(use_percentiles), + use_relative_action=use_relative_action, + ), + ACTION: _flatten_n1_7_modality_stats( + embodiment_stats=raw_stats, + embodiment_config=modality_config, + modality="action", + use_percentiles=bool(use_percentiles), + use_relative_action=use_relative_action, + ), + } + + +def _load_n1_7_checkpoint_action_horizon( + processor_kwargs: dict[str, Any], + embodiment_tag: str, +) -> int | None: + modality_configs = processor_kwargs.get("modality_configs", {}) + if not isinstance(modality_configs, dict): + return None + embodiment_config = modality_configs.get(embodiment_tag, {}) + if not isinstance(embodiment_config, dict): + return None + action_config = embodiment_config.get("action", {}) + if not isinstance(action_config, dict): + return None + delta_indices = action_config.get("delta_indices", []) + if not isinstance(delta_indices, list): + return None + return len(delta_indices) or None + + +def _load_n1_7_checkpoint_video_horizon( + processor_kwargs: dict[str, Any], + embodiment_tag: str, +) -> int | None: + modality_configs = processor_kwargs.get("modality_configs", {}) + if not isinstance(modality_configs, dict): + return None 
+    embodiment_config = modality_configs.get(embodiment_tag, {})
+    if not isinstance(embodiment_config, dict):
+        return None
+    video_config = embodiment_config.get("video", {})
+    if not isinstance(video_config, dict):
+        return None
+    delta_indices = video_config.get("delta_indices", [])
+    if not isinstance(delta_indices, list):
+        return None
+    return len(delta_indices) or None
+
+
+def _load_n1_7_checkpoint_video_modality_keys(
+    processor_kwargs: dict[str, Any],
+    embodiment_tag: str,
+) -> list[str] | None:
+    modality_configs = processor_kwargs.get("modality_configs", {})
+    if not isinstance(modality_configs, dict):
+        return None
+    embodiment_config = modality_configs.get(embodiment_tag, {})
+    if not isinstance(embodiment_config, dict):
+        return None
+    video_config = embodiment_config.get("video", {})
+    if not isinstance(video_config, dict):
+        return None
+    modality_keys = video_config.get("modality_keys", [])
+    if not isinstance(modality_keys, list):
+        return None
+    keys = [key for key in modality_keys if isinstance(key, str)]
+    return keys or None
+
+
+def _as_int_pair(value: Any) -> list[int] | None:
+    if not isinstance(value, (list, tuple)) or len(value) != 2:
+        return None
+    try:
+        return [int(value[0]), int(value[1])]
+    except (TypeError, ValueError):
+        return None
+
+
+def _as_optional_int(value: Any) -> int | None:
+    if value is None:
+        return None
+    try:
+        return int(value)
+    except (TypeError, ValueError):
+        return None
+
+
+def _as_optional_float(value: Any) -> float | None:
+    if value is None:
+        return None
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return None
+
+
+def _flatten_n1_7_modality_stats(
+    *,
+    embodiment_stats: dict[str, Any],
+    embodiment_config: dict[str, Any],
+    modality: str,
+    use_percentiles: bool,
+    use_relative_action: bool,
+) -> dict[str, list[float]]:
+    """Flatten one N1.7 modality's grouped statistics in checkpoint order.
+ + When checkpoints request percentile normalization, q01/q99 replace min/max + for regular groups. Relative action groups read from ``relative_action`` + stats and keep min/max, matching Isaac-GR00T's processor override. + """ + + source_stats = embodiment_stats.get(modality, {}) + modality_config = embodiment_config.get(modality, {}) + if not isinstance(source_stats, dict) or not isinstance(modality_config, dict): + return {} + modality_keys = modality_config.get("modality_keys", []) + if not isinstance(modality_keys, list): + return {} + + flattened: dict[str, list[float]] = {} + action_configs = modality_config.get("action_configs", []) if modality == "action" else [] + if not isinstance(action_configs, list): + action_configs = [] + relative_stats = embodiment_stats.get("relative_action", {}) + if not isinstance(relative_stats, dict): + relative_stats = {} + + for stat_name in ("min", "max", "mean", "std"): + values: list[float] = [] + source_stat_name = stat_name + if use_percentiles and stat_name == "min": + source_stat_name = "q01" + elif use_percentiles and stat_name == "max": + source_stat_name = "q99" + + for idx, modality_key in enumerate(modality_keys): + if not isinstance(modality_key, str): + continue + key_source_stats = source_stats + key_stat_name = source_stat_name + if modality == "action" and use_relative_action and idx < len(action_configs): + action_config = action_configs[idx] + if isinstance(action_config, dict) and _config_value(action_config.get("rep")) == "relative": + key_source_stats = relative_stats + key_stat_name = stat_name + key_stats = key_source_stats.get(modality_key, {}) + if not isinstance(key_stats, dict): + raise KeyError(f"Missing statistics for {modality}.{modality_key}") + raw_values = key_stats.get(key_stat_name) + if raw_values is None: + raise KeyError(f"Missing '{key_stat_name}' statistics for {modality}.{modality_key}") + values.extend(_as_float_list(raw_values)) + if values: + flattened[stat_name] = values + + 
return flattened + + +def _as_float_list(values: Any) -> list[float]: + if values is None: + return [] + if isinstance(values, torch.Tensor): + return values.detach().cpu().reshape(-1).float().tolist() + if isinstance(values, np.ndarray): + return values.reshape(-1).astype(np.float32).tolist() + if isinstance(values, (list, tuple)): + flattened: list[float] = [] + for value in values: + flattened.extend(_as_float_list(value)) + return flattened + return [float(values)] + + +def _config_value(value: Any) -> str: + if hasattr(value, "value"): + value = value.value + text = str(value).lower() + return { + "relative": "relative", + "absolute": "absolute", + "delta": "delta", + "eef": "eef", + "non_eef": "non_eef", + "default": "default", + "xyz_rot6d": "xyz+rot6d", + "xyz+rot6d": "xyz+rot6d", + "xyz_rotvec": "xyz+rotvec", + "xyz+rotvec": "xyz+rotvec", + }.get(text, text) + + +def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool: + if not stats: + return False + return any(bool(modality_stats) for modality_stats in stats.values()) + def make_groot_pre_post_processors( config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None @@ -94,6 +488,109 @@ def make_groot_pre_post_processors( Tuple of (preprocessor, postprocessor) pipelines """ + if config.model_version == GROOT_N1_7: + checkpoint_assets = _load_n1_7_checkpoint_processor_assets(config) + action_horizon = ( + checkpoint_assets.max_action_horizon + if checkpoint_assets is not None and checkpoint_assets.max_action_horizon is not None + else min(config.chunk_size, 40) + ) + valid_action_horizon = ( + checkpoint_assets.valid_action_horizon + if checkpoint_assets is not None and checkpoint_assets.valid_action_horizon is not None + else action_horizon + ) + checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None + padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {}) + embodiment_mapping = ( + 
checkpoint_assets.embodiment_mapping + if checkpoint_assets is not None + else dict(N1_7_EMBODIMENT_MAPPING) + ) + formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True + clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True + video_modality_keys = checkpoint_assets.video_modality_keys if checkpoint_assets is not None else None + try: + env_action_dim = int(config.output_features[ACTION].shape[0]) + except Exception: + env_action_dim = 0 + state_cache_key = f"groot_n1_7:{config.embodiment_tag}" + pack_step = GrootN17PackInputsStep( + state_horizon=1, + action_horizon=action_horizon, + valid_action_horizon=valid_action_horizon, + video_horizon=checkpoint_assets.video_horizon if checkpoint_assets is not None else None, + max_state_dim=config.max_state_dim, + max_action_dim=config.max_action_dim, + language_key="task", + formalize_language=formalize_language, + embodiment_tag=config.embodiment_tag, + embodiment_mapping=embodiment_mapping, + normalize_min_max=True, + stats=padded_stats, + clip_outliers=clip_outliers, + video_modality_keys=video_modality_keys, + raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None, + modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None, + state_cache_key=state_cache_key, + ) + + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + pack_step, + GrootN17VLMEncodeStep( + model_name=config.n1_7_backbone_model, + image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None, + image_target_size=checkpoint_assets.image_target_size + if checkpoint_assets is not None + else None, + shortest_image_edge=checkpoint_assets.shortest_image_edge + if checkpoint_assets is not None + else None, + crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None, + 
use_albumentations=checkpoint_assets.use_albumentations + if checkpoint_assets is not None + else False, + ), + DeviceProcessorStep(device=config.device), + ] + if checkpoint_assets is None: + action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep( + env_action_dim=env_action_dim, + stats=padded_stats, + normalize_min_max=True, + clip_normalized_action=True, + ) + else: + action_decode_step = GrootN17ActionDecodeStep( + env_action_dim=env_action_dim, + raw_stats=checkpoint_assets.raw_stats, + modality_config=checkpoint_assets.modality_config, + use_percentiles=checkpoint_assets.use_percentiles, + use_relative_action=checkpoint_assets.use_relative_action, + pack_step=pack_step, + state_cache_key=state_cache_key, + action_decode_transform=config.action_decode_transform, + ) + + output_steps: list[ProcessorStep] = [ + action_decode_step, + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + # Get horizon/dimension parameters from config # These should match the config used for the pretrained model # Default values match most GR00T configs (state_horizon=1, action_horizon=16) @@ -185,11 +682,79 @@ def make_groot_pre_post_processors( # GR00T specific processor steps -def _to_uint8_np_bhwc(img_t: torch.Tensor) -> np.ndarray: - # img_t: (B, C, H, W) float in [0,1] or uint8 +def _to_uint8_np_bthwc(img_t: torch.Tensor) -> np.ndarray: + # img_t: (B, C, H, W) or (B, T, C, H, W), float in [0,1] or uint8 if img_t.dtype.is_floating_point: img_t = (img_t.clamp(0, 1) * 255.0).to(torch.uint8) - return rearrange(img_t.cpu().numpy(), "b c h w -> b h w c") + if img_t.dim() == 4: + return rearrange(img_t.cpu().numpy(), "b c h w -> b 1 h w 
c") + if img_t.dim() == 5: + return rearrange(img_t.cpu().numpy(), "b t c h w -> b t h w c") + raise ValueError(f"Expected image tensor shape (B, C, H, W) or (B, T, C, H, W), got {tuple(img_t.shape)}") + + +def _to_uint8_np_bhwc(img_t: torch.Tensor) -> np.ndarray: + return _to_uint8_np_bthwc(img_t)[:, 0] + + +def _align_video_horizon(video: np.ndarray, horizon: int | None) -> np.ndarray: + """Match the checkpoint video horizon by truncating or left-padding frames.""" + + if horizon is None or horizon <= 0: + return video + current = video.shape[1] + if current == horizon: + return video + if current > horizon: + return video[:, -horizon:] + pad = np.repeat(video[:, :1], horizon - current, axis=1) + return np.concatenate([pad, video], axis=1) + + +def _infer_n1_7_batch_size_and_device( + obs: dict[str, Any], action: torch.Tensor | None +) -> tuple[int, torch.device]: + for value in list(obs.values()) + [action]: + if isinstance(value, torch.Tensor): + return value.shape[0], value.device + video = obs.get("video") + if isinstance(video, np.ndarray): + return video.shape[0], torch.device("cpu") + return 1, torch.device("cpu") + + +def _prepare_n1_7_language_batch( + language: Any, + batch_size: int, + *, + formalize_language: bool, +) -> list[str]: + default_language = "Perform the task." + if language is None or (isinstance(language, str) and language == ""): + languages = [default_language] * batch_size + elif isinstance(language, str): + languages = [language] * batch_size + elif isinstance(language, (list, tuple)): + languages = list(language) + if len(languages) == 0: + languages = [default_language] * batch_size + elif len(languages) == 1 and batch_size > 1: + languages = languages * batch_size + elif len(languages) != batch_size: + raise ValueError( + f"language batch has {len(languages)} entries, but GR00T N1.7 input batch has {batch_size}." 
+ ) + else: + languages = [str(language)] * batch_size + + formatted = [] + for item in languages: + text = str(item) if item else default_language + if formalize_language: + text = text.lower() + text = "".join(ch for ch in text if ch.isalnum() or ch.isspace() or ch == "_") + formatted.append(text) + return formatted def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO) -> ProcessorMixin: @@ -215,6 +780,116 @@ def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS return proc +def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> ProcessorMixin: + try: + from transformers import ( + AutoTokenizer, + Qwen2VLImageProcessorFast, + Qwen3VLProcessor, + Qwen3VLVideoProcessor, + ) + except ImportError as exc: + raise ImportError( + "GR00T N1.7 preprocessing requires a transformers version with Qwen3-VL processor support. " + "Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`." + ) from exc + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + image_processor = Qwen2VLImageProcessorFast.from_pretrained(model_name, trust_remote_code=True) + video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True) + proc = Qwen3VLProcessor( + image_processor=image_processor, + tokenizer=tokenizer, + video_processor=video_processor, + chat_template=tokenizer.chat_template, + ) + proc.tokenizer.padding_side = "left" + return proc + + +def _transform_n1_7_image_for_vlm( + image: Image.Image, + *, + image_crop_size: list[int] | None, + image_target_size: list[int] | None, + shortest_image_edge: int | None, + crop_fraction: float | None, + use_albumentations: bool = False, +) -> Image.Image: + if image_target_size is None: + return image + + target_h, target_w = image_target_size + if image.mode != "RGB": + image = image.convert("RGB") + + if use_albumentations: + try: + import cv2 + except ImportError as exc: + raise 
ImportError( + "GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless." + ) from exc + + image_np = np.asarray(image) + height, width = image_np.shape[:2] + if height != width: + square_edge = max(height, width) + pad_h = square_edge - height + pad_w = square_edge - width + image_np = cv2.copyMakeBorder( + image_np, + pad_h // 2, + pad_h - pad_h // 2, + pad_w // 2, + pad_w - pad_w // 2, + cv2.BORDER_CONSTANT, + value=(0, 0, 0), + ) + + resize_edge = shortest_image_edge or target_h + if image_np.shape[:2] != (resize_edge, resize_edge): + image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA) + + if crop_fraction is None and image_crop_size is not None: + crop_fraction = image_crop_size[0] / float(target_h) + if crop_fraction is not None and 0.0 < crop_fraction < 1.0: + height, width = image_np.shape[:2] + crop_h = max(1, int(height * crop_fraction)) + crop_w = max(1, int(width * crop_fraction)) + top = max(0, (height - crop_h) // 2) + left = max(0, (width - crop_w) // 2) + image_np = image_np[top : top + crop_h, left : left + crop_w] + + if image_np.shape[:2] != (target_h, target_w): + image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA) + return Image.fromarray(image_np) + + square_edge = max(image.width, image.height) + if image.width != image.height: + padded = Image.new("RGB", (square_edge, square_edge)) + left = (square_edge - image.width) // 2 + top = (square_edge - image.height) // 2 + padded.paste(image, (left, top)) + image = padded + + resize_edge = shortest_image_edge or target_h + image = image.resize((resize_edge, resize_edge), Image.Resampling.BICUBIC) + + if crop_fraction is None and image_crop_size is not None: + crop_fraction = image_crop_size[0] / float(target_h) + if crop_fraction is not None and 0.0 < crop_fraction < 1.0: + crop_w = max(1, int(round(image.width * crop_fraction))) + crop_h = max(1, int(round(image.height * crop_fraction))) + left = 
max(0, (image.width - crop_w) // 2) + top = max(0, (image.height - crop_h) // 2) + image = image.crop((left, top, left + crop_w, top + crop_h)) + + if image.size != (target_w, target_h): + image = image.resize((target_w, target_h), Image.Resampling.BICUBIC) + return image + + @dataclass @ProcessorStepRegistry.register(name="groot_pack_inputs_v3") class GrootPackInputsStep(ProcessorStep): @@ -437,6 +1112,354 @@ class GrootPackInputsStep(ProcessorStep): self.stats = reconstructed +@dataclass +@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1") +class GrootN17PackInputsStep(ProcessorStep): + """Pack LeRobot transitions into the raw tensor layout expected by N1.7. + + This step preserves the checkpoint's camera order, video horizon, language + formatting, normalization statistics, action mask semantics, and embodiment + id mapping before the Qwen3-VL processor sees the sample. + """ + + state_horizon: int = 1 + action_horizon: int = 40 + valid_action_horizon: int = 40 + video_horizon: int | None = None + max_state_dim: int = 132 + max_action_dim: int = 132 + language_key: str = "task" + formalize_language: bool = True + embodiment_tag: str = "new_embodiment" + embodiment_mapping: dict[str, int] = field(default_factory=lambda: dict(N1_7_EMBODIMENT_MAPPING)) + normalize_min_max: bool = True + stats: dict[str, dict[str, Any]] | None = None + clip_outliers: bool = True + video_modality_keys: list[str] | None = None + raw_stats: dict[str, Any] | None = None + modality_config: dict[str, Any] | None = None + state_cache_key: str = "" + _last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False) + + def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]: + available = {key for key in obs if key.startswith(OBS_IMAGES)} + if not available and OBS_IMAGE in obs: + return [OBS_IMAGE] + if not self.video_modality_keys: + return sorted(available) + + ordered: list[str] = [] + for modality_key in self.video_modality_keys: + 
candidates = [f"{OBS_IMAGES}.{modality_key}"] + if modality_key == "wrist_image": + candidates.append(f"{OBS_IMAGES}.image2") + elif modality_key == "image": + candidates.append(f"{OBS_IMAGES}.image") + + match = next((candidate for candidate in candidates if candidate in available), None) + if match is not None: + ordered.append(match) + + if not ordered: + return sorted(available) + return ordered + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs = transition.get(TransitionKey.OBSERVATION, {}) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} + + def _align_vec(vec: Any, target_dim: int, *, default: float) -> torch.Tensor: + t = torch.as_tensor(vec) + t = t.flatten().to( + dtype=torch.float32, + device=next( + (v.device for v in obs.values() if isinstance(v, torch.Tensor)), torch.device("cpu") + ), + ) + d = int(t.shape[-1]) if t.numel() > 0 else 0 + if d == target_dim: + return t + if d < target_dim: + pad = torch.full((target_dim - d,), default, dtype=t.dtype, device=t.device) + return torch.cat([t, pad], dim=0) + return t[:target_dim] + + def _min_max_norm(x: torch.Tensor, key: str) -> torch.Tensor: + if not self.normalize_min_max or self.stats is None or key not in self.stats: + return x + stats_k = self.stats[key] + last_dim = x.shape[-1] + min_v = _align_vec(stats_k.get("min", torch.zeros(last_dim)), last_dim, default=0.0) + max_v = _align_vec(stats_k.get("max", torch.ones(last_dim)), last_dim, default=1.0) + denom = max_v - min_v + mask = denom != 0 + safe_denom = torch.where(mask, denom, torch.ones_like(denom)) + mapped = 2 * (x - min_v) / safe_denom - 1 + normalized = torch.where(mask, mapped, torch.zeros_like(mapped)) + if self.clip_outliers: + normalized = normalized.clamp(-1.0, 1.0) + return normalized + + def _cache_raw_state(state: torch.Tensor) -> None: + if self.modality_config is None or self.raw_stats is None: + return + state_config = self.modality_config.get("state", {}) + if not 
isinstance(state_config, dict): + return + state_keys = state_config.get("modality_keys", []) + if not isinstance(state_keys, list): + return + + raw_state = state.detach().cpu().float().numpy() + start_idx = 0 + grouped: dict[str, np.ndarray] = {} + for key in state_keys: + if not isinstance(key, str): + continue + key_stats = self.raw_stats.get("state", {}).get(key, {}) + dim = len(key_stats.get("mean") or key_stats.get("min") or key_stats.get("q01") or []) + if dim <= 0: + continue + grouped[key] = raw_state[:, start_idx : start_idx + dim] + start_idx += dim + if grouped: + self._last_raw_state = grouped + _N1_7_RAW_STATE_CACHE[_n1_7_state_cache_key(self.state_cache_key)] = grouped + + img_keys = self._ordered_image_keys(obs) + if img_keys: + cams = [_align_video_horizon(_to_uint8_np_bthwc(obs[k]), self.video_horizon) for k in img_keys] + video = np.stack(cams, axis=2) # (B, T, V, H, W, C) + obs["video"] = video + image_keys_to_remove = [key for key in obs if key.startswith(OBS_IMAGES)] + if OBS_IMAGE in obs: + image_keys_to_remove.append(OBS_IMAGE) + for k in image_keys_to_remove: + obs.pop(k, None) + + bsz, _device = _infer_n1_7_batch_size_and_device(obs, transition.get(TransitionKey.ACTION)) + comp["language"] = _prepare_n1_7_language_batch( + comp.get(self.language_key), + bsz, + formalize_language=self.formalize_language, + ) + + if OBS_STATE in obs: + state = obs[OBS_STATE] + if state.dim() != 2: + raise ValueError(f"state must be (B, D), got {tuple(state.shape)}") + bsz, dim = state.shape + if dim > self.max_state_dim: + raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.") + _cache_raw_state(state) + if self.normalize_min_max: + state = _min_max_norm(state, OBS_STATE) + state = state.unsqueeze(1) + if dim < self.max_state_dim: + pad = torch.zeros(bsz, 1, self.max_state_dim - dim, dtype=state.dtype, device=state.device) + state = torch.cat([state, pad], dim=2) + obs["state"] = state + + action = 
transition.get(TransitionKey.ACTION) + if isinstance(action, torch.Tensor): + if action.dim() == 2: + action = action.unsqueeze(1) + elif action.dim() == 3: + pass + else: + raise ValueError(f"action must be (B, D) or (B, T, D), got {tuple(action.shape)}") + + bsz, horizon, dim = action.shape + if horizon > self.action_horizon: + raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.") + if dim > self.max_action_dim: + raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.") + if self.normalize_min_max: + flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION) + action = flat.view(bsz, horizon, dim) + valid_dim = min(dim, self.max_action_dim) + valid_horizon = min(horizon, self.valid_action_horizon, self.action_horizon) + if dim < self.max_action_dim: + pad = torch.zeros( + bsz, horizon, self.max_action_dim - dim, dtype=action.dtype, device=action.device + ) + action = torch.cat([action, pad], dim=2) + if horizon < self.action_horizon: + pad = torch.zeros( + bsz, + self.action_horizon - horizon, + self.max_action_dim, + dtype=action.dtype, + device=action.device, + ) + action = torch.cat([action, pad], dim=1) + horizon = self.action_horizon + if valid_horizon < horizon: + action = action.clone() + action[:, valid_horizon:, :] = 0 + action_mask = torch.zeros( + bsz, horizon, self.max_action_dim, dtype=torch.float32, device=action.device + ) + action_mask[:, :valid_horizon, :valid_dim] = 1.0 + transition[TransitionKey.ACTION] = action + comp["action_mask"] = action_mask + + emb_id = self.embodiment_mapping.get(self.embodiment_tag, 0) + bsz, device = _infer_n1_7_batch_size_and_device(obs, transition.get(TransitionKey.ACTION)) + if "action_mask" not in comp: + action_mask = torch.zeros(bsz, self.action_horizon, dtype=torch.float32, device=device) + valid_horizon = min(self.valid_action_horizon, self.action_horizon) + action_mask[:, :valid_horizon] = 1.0 + comp["action_mask"] = action_mask + 
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.int32, device=device) + + transition[TransitionKey.OBSERVATION] = obs + transition[TransitionKey.COMPLEMENTARY_DATA] = comp + return transition + + def transform_features(self, features): + return features + + def get_config(self) -> dict[str, Any]: + return { + "state_horizon": self.state_horizon, + "action_horizon": self.action_horizon, + "valid_action_horizon": self.valid_action_horizon, + "video_horizon": self.video_horizon, + "max_state_dim": self.max_state_dim, + "max_action_dim": self.max_action_dim, + "language_key": self.language_key, + "formalize_language": self.formalize_language, + "embodiment_tag": self.embodiment_tag, + "embodiment_mapping": self.embodiment_mapping, + "normalize_min_max": self.normalize_min_max, + "clip_outliers": self.clip_outliers, + "video_modality_keys": self.video_modality_keys, + "raw_stats": self.raw_stats, + "modality_config": self.modality_config, + "state_cache_key": self.state_cache_key, + } + + def get_cached_raw_state(self) -> dict[str, np.ndarray] | None: + """Return the latest unnormalized state split by checkpoint modality key.""" + + return self._last_raw_state + + def state_dict(self) -> dict[str, torch.Tensor]: + if not self.stats: + return {} + + flat: dict[str, torch.Tensor] = {} + for key, sub in self.stats.items(): + for stat_name, value in sub.items(): + flat[f"{key}.{stat_name}"] = torch.as_tensor(value).cpu() + return flat + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + if not state: + return + reconstructed: dict[str, dict[str, Any]] = {} + for flat_key, tensor in state.items(): + if "." 
in flat_key: + key, stat_name = flat_key.rsplit(".", 1) + reconstructed.setdefault(key, {})[stat_name] = tensor + if reconstructed: + self.stats = reconstructed + + +@dataclass +@ProcessorStepRegistry.register(name="groot_n1_7_vlm_encode_v1") +class GrootN17VLMEncodeStep(ProcessorStep): + """Tokenize N1.7's packed video-language prompt with the Qwen3-VL processor. + + The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes + an image item in the same chat message so the resulting image tokens match + the temporal VLM packing used by Isaac-GR00T. + """ + + model_name: str = GROOT_N1_7_BACKBONE_MODEL + image_crop_size: list[int] | None = None + image_target_size: list[int] | None = None + shortest_image_edge: int | None = None + crop_fraction: float | None = None + use_albumentations: bool = False + _proc: ProcessorMixin | None = field(default=None, init=False, repr=False) + + @property + def proc(self) -> ProcessorMixin: + if self._proc is None: + self._proc = _build_n1_7_processor(self.model_name) + return self._proc + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs = transition.get(TransitionKey.OBSERVATION, {}) or {} + comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} + video = obs.get("video") + if video is None: + return transition + + languages = _prepare_n1_7_language_batch( + comp.get("language"), + video.shape[0], + formalize_language=False, + ) + + texts: list[str] = [] + images: list[Image.Image] = [] + for batch_idx in range(video.shape[0]): + sample = video[batch_idx] # (T, V, H, W, C) + sample_images = [ + _transform_n1_7_image_for_vlm( + Image.fromarray(sample[timestep, view_idx]), + image_crop_size=self.image_crop_size, + image_target_size=self.image_target_size, + shortest_image_edge=self.shortest_image_edge, + crop_fraction=self.crop_fraction, + use_albumentations=self.use_albumentations, + ) + for timestep in range(sample.shape[0]) + for view_idx in range(sample.shape[1]) + ] + 
conversation = [ + { + "role": "user", + "content": [ + *[{"type": "image", "image": image} for image in sample_images], + {"type": "text", "text": languages[batch_idx]}, + ], + } + ] + texts.append( + self.proc.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=False, + ) + ) + images.extend(sample_images) + + encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True) + for key, value in encoded.items(): + comp[key] = value + obs.pop("video", None) + transition[TransitionKey.OBSERVATION] = obs + transition[TransitionKey.COMPLEMENTARY_DATA] = comp + return transition + + def transform_features(self, features): + return features + + def get_config(self) -> dict[str, Any]: + return { + "model_name": self.model_name, + "image_crop_size": self.image_crop_size, + "image_target_size": self.image_target_size, + "shortest_image_edge": self.shortest_image_edge, + "crop_fraction": self.crop_fraction, + "use_albumentations": self.use_albumentations, + } + + @dataclass @ProcessorStepRegistry.register(name="groot_eagle_encode_v3") class GrootEagleEncodeStep(ProcessorStep): @@ -572,6 +1595,243 @@ class GrootEagleCollateStep(ProcessorStep): return features +def _stat_dim_from_entry(entry: dict[str, Any]) -> int: + for stat_name in ("mean", "q01", "min", "max", "std"): + value = entry.get(stat_name) + if isinstance(value, list) and len(value) > 0: + return len(value) + return 0 + + +def _n1_7_decode_stats_for_action( + raw_stats: dict[str, Any], + key: str, + action_config: dict[str, Any], + *, + use_relative_action: bool, + use_percentiles: bool, +) -> tuple[np.ndarray, np.ndarray]: + """Select the min/max arrays needed to decode one checkpoint action group.""" + + is_relative = use_relative_action and _config_value(action_config.get("rep")) == "relative" + modality = "relative_action" if is_relative else "action" + stats = raw_stats.get(modality, {}).get(key, {}) + if not isinstance(stats, dict): + raise KeyError(f"Missing N1.7 
statistics for {modality}.{key}") + min_name = "min" if is_relative else ("q01" if use_percentiles else "min") + max_name = "max" if is_relative else ("q99" if use_percentiles else "max") + if min_name not in stats or max_name not in stats: + raise KeyError(f"Missing '{min_name}'/'{max_name}' statistics for {modality}.{key}") + return np.asarray(stats[min_name], dtype=np.float32), np.asarray(stats[max_name], dtype=np.float32) + + +def _unnormalize_min_max(action: np.ndarray, min_v: np.ndarray, max_v: np.ndarray) -> np.ndarray: + return (np.clip(action, -1.0, 1.0) + 1.0) * 0.5 * (max_v - min_v) + min_v + + +def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray: + rows = rot6d.reshape(2, 3).astype(np.float64) + row1 = rows[0] / np.linalg.norm(rows[0]) + row2 = rows[1] - np.dot(row1, rows[1]) * row1 + row2 = row2 / np.linalg.norm(row2) + row3 = np.cross(row1, row2) + return np.vstack([row1, row2, row3]) + + +def _xyz_rot6d_to_homogeneous(xyz_rot6d: np.ndarray) -> np.ndarray: + transform = np.eye(4, dtype=np.float64) + transform[:3, :3] = _rot6d_to_matrix(xyz_rot6d[3:]) + transform[:3, 3] = xyz_rot6d[:3] + return transform + + +def _homogeneous_to_xyz_rot6d(transform: np.ndarray) -> np.ndarray: + return np.concatenate([transform[:3, 3], transform[:2, :3].reshape(-1)], axis=0) + + +def _relative_eef_to_absolute(action: np.ndarray, reference_state: np.ndarray) -> np.ndarray: + """Convert relative EEF deltas in xyz+rot6d format to absolute EEF poses.""" + + out = np.empty_like(action, dtype=np.float64) + for batch_idx in range(action.shape[0]): + reference = _xyz_rot6d_to_homogeneous(reference_state[batch_idx]) + for timestep in range(action.shape[1]): + relative = _xyz_rot6d_to_homogeneous(action[batch_idx, timestep]) + out[batch_idx, timestep] = _homogeneous_to_xyz_rot6d(reference @ relative) + return out.astype(np.float32) + + +def _n1_7_action_group_slice( + action_keys: list[Any], decoded_groups: dict[str, np.ndarray], target_key: str +) -> slice: + start_idx = 0 + 
for key in action_keys: + if not isinstance(key, str) or key not in decoded_groups: + continue + dim = decoded_groups[key].shape[-1] + end_idx = start_idx + dim + if key == target_key: + return slice(start_idx, end_idx) + start_idx = end_idx + + raise KeyError(f"Missing N1.7 action group '{target_key}' required by action decode transform.") + + +def _apply_n1_7_action_decode_transform( + decoded: np.ndarray, + *, + transform: str | None, + action_keys: list[Any], + decoded_groups: dict[str, np.ndarray], +) -> np.ndarray: + if transform is None: + return decoded + + if transform == GROOT_ACTION_DECODE_TRANSFORM_LIBERO: + gripper_slice = _n1_7_action_group_slice(action_keys, decoded_groups, "gripper") + if gripper_slice.stop is None or gripper_slice.stop > decoded.shape[-1]: + raise ValueError( + "N1.7 LIBERO action decode transform requested, but the decoded gripper action " + "is outside the sliced environment action." + ) + if gripper_slice.stop - gripper_slice.start != 1: + raise ValueError("N1.7 LIBERO action decode transform expects a scalar gripper action.") + + transformed = decoded.copy() + gripper = transformed[..., gripper_slice] + transformed[..., gripper_slice] = -np.sign(2.0 * gripper - 1.0) + return transformed + + raise ValueError(f"Unsupported N1.7 action decode transform '{transform}'.") + + +@dataclass +@ProcessorStepRegistry.register(name="groot_n1_7_action_decode_v1") +class GrootN17ActionDecodeStep(ProcessorStep): + """Decode the full 132-D N1.7 model action back to environment actions. + + N1.7 predicts checkpoint-order action groups. This step unnormalizes each + group with the checkpoint stats, converts relative groups to absolute values + using the raw state cached during packing, concatenates groups in checkpoint + order, and finally slices to the environment action dimension. 
+ """ + + env_action_dim: int = 0 + raw_stats: dict[str, Any] | None = None + modality_config: dict[str, Any] | None = None + use_percentiles: bool = False + use_relative_action: bool = False + state_cache_key: str = "" + action_decode_transform: str | None = None + pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, torch.Tensor): + return transition + if self.raw_stats is None or self.modality_config is None: + return transition + + action_config = self.modality_config.get("action", {}) + if not isinstance(action_config, dict): + return transition + action_keys = action_config.get("modality_keys", []) + action_configs = action_config.get("action_configs", []) + if not isinstance(action_keys, list) or not isinstance(action_configs, list): + return transition + + action_np = action.detach().cpu().float().numpy() + decoded_groups: dict[str, np.ndarray] = {} + start_idx = 0 + for idx, key in enumerate(action_keys): + if not isinstance(key, str): + continue + stats_entry = self.raw_stats.get("action", {}).get(key, {}) + if not isinstance(stats_entry, dict): + continue + dim = _stat_dim_from_entry(stats_entry) + if dim <= 0: + continue + cfg = ( + action_configs[idx] + if idx < len(action_configs) and isinstance(action_configs[idx], dict) + else {} + ) + normalized = action_np[..., start_idx : start_idx + dim] + min_v, max_v = _n1_7_decode_stats_for_action( + self.raw_stats, + key, + cfg, + use_relative_action=self.use_relative_action, + use_percentiles=self.use_percentiles, + ) + decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v) + start_idx += dim + + if self.use_relative_action: + raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None + if raw_state is None: + raw_state = _N1_7_RAW_STATE_CACHE.get(_n1_7_state_cache_key(self.state_cache_key)) + if 
raw_state is None: + raise RuntimeError( + "GrootN17ActionDecodeStep requires cached raw state from GrootN17PackInputsStep " + "to convert relative N1.7 actions back to absolute actions." + ) + for idx, key in enumerate(action_keys): + if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs): + continue + cfg = action_configs[idx] + if not isinstance(cfg, dict) or _config_value(cfg.get("rep")) != "relative": + continue + state_key = cfg.get("state_key") or key + if state_key not in raw_state: + raise KeyError(f"Missing cached raw state '{state_key}' for relative N1.7 action '{key}'") + reference = raw_state[state_key] + action_type = _config_value(cfg.get("type")) + action_format = _config_value(cfg.get("format")) + if action_type == "non_eef": + decoded_groups[key] = decoded_groups[key] + reference[:, None, :] + elif action_type == "eef" and action_format == "xyz+rot6d": + decoded_groups[key] = _relative_eef_to_absolute(decoded_groups[key], reference) + else: + raise ValueError(f"Unsupported relative N1.7 action config for '{key}': {cfg}") + + if not decoded_groups: + return transition + + decoded = np.concatenate( + [decoded_groups[key] for key in action_keys if isinstance(key, str) and key in decoded_groups], + axis=-1, + ) + if self.env_action_dim and decoded.shape[-1] > self.env_action_dim: + decoded = decoded[..., : self.env_action_dim] + decoded = _apply_n1_7_action_decode_transform( + decoded, + transform=self.action_decode_transform, + action_keys=action_keys, + decoded_groups=decoded_groups, + ) + new_transition = transition.copy() + new_transition[TransitionKey.ACTION] = torch.as_tensor( + decoded, dtype=action.dtype, device=action.device + ) + return new_transition + + def transform_features(self, features): + return features + + def get_config(self) -> dict[str, Any]: + return { + "env_action_dim": self.env_action_dim, + "raw_stats": self.raw_stats, + "modality_config": self.modality_config, + "use_percentiles": 
self.use_percentiles, + "use_relative_action": self.use_relative_action, + "state_cache_key": self.state_cache_key, + "action_decode_transform": self.action_decode_transform, + } + + @dataclass @ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1") class GrootActionUnpackUnnormalizeStep(ProcessorStep): @@ -579,6 +1839,9 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): # Apply inverse of min-max normalization if it was used in preprocessor normalize_min_max: bool = True stats: dict[str, dict[str, Any]] | None = None + clip_normalized_action: bool = False + libero_gripper_action: bool = False + libero_gripper_binarize: bool = True def __call__(self, transition: EnvTransition) -> EnvTransition: # Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model) @@ -586,10 +1849,9 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): if not isinstance(action, torch.Tensor): return transition - # Select last timestep and slice to env dimension - if action.dim() == 3: - action = action[:, -1, :] - # Now action is (B, D_model) + # Slice to env dimension while preserving an optional action horizon. + # Sync rollout postprocesses selected actions as (B, D); RTC postprocesses + # chunks as (B, T, D), matching Isaac-GR00T's decode_action contract. 
if self.env_action_dim and action.shape[-1] >= self.env_action_dim: action = action[..., : self.env_action_dim] @@ -597,6 +1859,8 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): # forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0 # inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min if self.normalize_min_max and self.stats is not None: + if self.clip_normalized_action: + action = action.clamp(-1.0, 1.0) stats_k = self.stats.get(ACTION, {}) d = action.shape[-1] min_v = torch.as_tensor( @@ -617,6 +1881,15 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): inv = (action + 1.0) * 0.5 * safe_denom + min_v action = torch.where(mask, inv, min_v) + if self.libero_gripper_action and action.shape[-1] >= 7: + gripper = action[..., -1] + if self.libero_gripper_binarize: + gripper = -torch.sign(2.0 * gripper - 1.0) + else: + gripper = -(2.0 * gripper - 1.0) + action = action.clone() + action[..., -1] = gripper + transition[TransitionKey.ACTION] = action return transition @@ -632,6 +1905,9 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): return { "env_action_dim": self.env_action_dim, "normalize_min_max": self.normalize_min_max, + "clip_normalized_action": self.clip_normalized_action, + "libero_gripper_action": self.libero_gripper_action, + "libero_gripper_binarize": self.libero_gripper_binarize, } def state_dict(self) -> dict[str, torch.Tensor]: diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py new file mode 100644 index 000000000..34c966e66 --- /dev/null +++ b/tests/policies/groot/test_groot_n1_7.py @@ -0,0 +1,1893 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import sys +from types import SimpleNamespace +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from torch import nn + +from lerobot.configs import FeatureType, PolicyFeature +from lerobot.policies.factory import make_policy_config, make_pre_post_processors +from lerobot.policies.groot.configuration_groot import ( + GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + GROOT_N1_5, + GROOT_N1_5_BASE_MODEL, + GROOT_N1_7, + GROOT_N1_7_BASE_MODEL, + GrootConfig, + infer_groot_n1_7_action_execution_horizon, + infer_groot_n1_7_action_horizon, +) +from lerobot.policies.groot.modeling_groot import GrootPolicy +from lerobot.policies.groot.processor_groot import ( + GrootActionUnpackUnnormalizeStep, + GrootEagleEncodeStep, + GrootN17ActionDecodeStep, + GrootN17PackInputsStep, + GrootN17VLMEncodeStep, + _transform_n1_7_image_for_vlm, + make_groot_pre_post_processors, +) +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import TransitionKey +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + + +def _groot_features( + state_dim: int, action_dim: int +) -> tuple[dict[str, PolicyFeature], dict[str, PolicyFeature]]: + return ( + { + f"{OBS_IMAGES}.front": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 256, 256)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)), + }, + {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}, + ) + + +def _groot_config(model_version: str) -> GrootConfig: + input_features, output_features = _groot_features(state_dim=8, 
action_dim=7) + kwargs = {} + if model_version == GROOT_N1_7: + kwargs["action_decode_transform"] = GROOT_ACTION_DECODE_TRANSFORM_LIBERO + return GrootConfig( + model_version=model_version, + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=False, + **kwargs, + ) + + +def _raw_n1_7_libero_config(model_path) -> GrootConfig: + input_features, output_features = _groot_features(state_dim=8, action_dim=7) + return GrootConfig( + model_version=GROOT_N1_7, + base_model_path=str(model_path), + embodiment_tag="libero_sim", + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=False, + action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + ) + + +def test_n1_7_backbone_accepts_transformers_5_layout_and_forwards_mm_token_type_ids(monkeypatch): + from transformers.feature_extraction_utils import BatchFeature + + import lerobot.policies.groot.groot_n1_7 as groot_n1_7 + + class FakeLanguageModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([nn.Linear(1, 1) for _ in range(2)]) + + class FakeInnerModel(nn.Module): + def __init__(self): + super().__init__() + self.language_model = FakeLanguageModel() + self.visual = nn.Linear(1, 1) + + class FakeQwen3VLForConditionalGeneration(nn.Module): + config = SimpleNamespace(image_token_id=42, video_token_id=43) + + def __init__(self): + super().__init__() + self.model = FakeInnerModel() + self.forward_kwargs = None + + @classmethod + def from_pretrained(cls, *args, **kwargs): + return cls() + + @classmethod + def _from_config(cls, *args, **kwargs): + return cls() + + def eval(self): + super().eval() + return self + + def forward(self, **kwargs): + self.forward_kwargs = kwargs + assert "mm_token_type_ids" in kwargs + batch_size, sequence_length = kwargs["input_ids"].shape + features = torch.arange(batch_size * sequence_length * 4, dtype=torch.float32).view( + batch_size, sequence_length, 4 + ) + return 
SimpleNamespace(hidden_states=[features]) + + monkeypatch.setattr( + groot_n1_7, + "metadata", + SimpleNamespace(version=lambda package: "5.3.0" if package == "transformers" else "0"), + raising=False, + ) + monkeypatch.setattr(groot_n1_7, "Qwen3VLForConditionalGeneration", FakeQwen3VLForConditionalGeneration) + + backbone = groot_n1_7.Qwen3Backbone( + model_name="nvidia/Cosmos-Reason2-2B", + select_layer=1, + use_flash_attention=False, + ) + + assert len(backbone.language_model.layers) == 1 + output = backbone.forward( + BatchFeature( + data={ + "input_ids": torch.tensor([[1, 42, 2]]), + "attention_mask": torch.tensor([[1, 1, 0]]), + "mm_token_type_ids": torch.tensor([[0, 1, 0]]), + "pixel_values": torch.zeros(1, 3, 2, 2), + "image_grid_thw": torch.ones(1, 3, dtype=torch.long), + } + ) + ) + + assert backbone.model.forward_kwargs["mm_token_type_ids"].tolist() == [[0, 1, 0]] + assert output["backbone_features"].shape == (1, 3, 4) + + output = backbone.forward( + BatchFeature( + data={ + "input_ids": torch.tensor([[1, 42, 43, 2]]), + "attention_mask": torch.tensor([[1, 1, 1, 0]]), + "pixel_values": torch.zeros(1, 3, 2, 2), + "image_grid_thw": torch.ones(1, 3, dtype=torch.long), + "pixel_values_videos": torch.zeros(1, 3, 2, 2), + "video_grid_thw": torch.ones(1, 3, dtype=torch.long), + } + ) + ) + + assert backbone.model.forward_kwargs["mm_token_type_ids"].tolist() == [[0, 1, 2, 0]] + assert backbone.model.forward_kwargs["mm_token_type_ids"].dtype == torch.int32 + assert output["backbone_features"].shape == (1, 4, 4) + + +def test_n1_7_backbone_preserves_missing_qwen_optional_dependency_error(monkeypatch): + import lerobot.policies.groot.groot_n1_7 as groot_n1_7 + + monkeypatch.setattr( + groot_n1_7, + "metadata", + SimpleNamespace(version=lambda package: "5.3.0" if package == "transformers" else "0"), + raising=False, + ) + monkeypatch.setattr(groot_n1_7, "Qwen3VLForConditionalGeneration", None) + + with pytest.raises(ImportError, 
match="Qwen3VLForConditionalGeneration is required"): + groot_n1_7.Qwen3Backbone( + model_name="nvidia/Cosmos-Reason2-2B", + select_layer=0, + use_flash_attention=False, + ) + + +def _write_raw_n1_7_libero_checkpoint(path): + path.mkdir() + (path / "config.json").write_text( + json.dumps( + { + "model_type": "Gr00tN1d7", + "architectures": ["Gr00tN1d7"], + "model_name": "nvidia/Cosmos-Reason2-2B", + "action_horizon": 40, + "max_state_dim": 132, + "max_action_dim": 132, + "image_target_size": [256, 256], + } + ) + ) + (path / "processor_config.json").write_text( + json.dumps( + { + "processor_class": "Gr00tN1d7Processor", + "processor_kwargs": { + "clip_outliers": True, + "formalize_language": True, + "image_crop_size": [230, 230], + "image_target_size": [256, 256], + "shortest_image_edge": 256, + "crop_fraction": 0.95, + "use_albumentations": True, + "max_action_horizon": 40, + "max_state_dim": 132, + "max_action_dim": 132, + "use_percentiles": True, + "use_relative_action": True, + "modality_configs": { + "libero_sim": { + "video": { + "delta_indices": [0], + "modality_keys": ["image", "wrist_image"], + }, + "state": { + "delta_indices": [0], + "modality_keys": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"], + }, + "action": { + "delta_indices": list(range(16)), + "modality_keys": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"], + }, + "language": { + "delta_indices": [0], + "modality_keys": ["annotation.human.action.task_description"], + }, + } + }, + }, + } + ) + ) + (path / "embodiment_id.json").write_text(json.dumps({"libero_sim": 42})) + (path / "statistics.json").write_text( + json.dumps( + { + "libero_sim": { + "state": { + "x": _stats([0.0]), + "y": _stats([1.0]), + "z": _stats([2.0]), + "roll": _stats([3.0]), + "pitch": _stats([4.0]), + "yaw": _stats([5.0]), + "gripper": _stats([6.0, 7.0]), + }, + "action": { + "x": _stats([10.0]), + "y": _stats([11.0]), + "z": _stats([12.0]), + "roll": _stats([13.0]), + "pitch": _stats([14.0]), + "yaw": 
_stats([15.0]), + "gripper": _stats([16.0]), + }, + "relative_action": {}, + } + } + ) + ) + + +def _stats(values): + return { + "min": values, + "max": [value + 100.0 for value in values], + "mean": [value + 50.0 for value in values], + "std": [1.0 for _ in values], + "q01": [value + 1.0 for value in values], + "q99": [value + 99.0 for value in values], + } + + +def _expected_albumentations_eval_image(image_np, cv2, *, target_size, shortest_edge, crop_fraction): + height, width = image_np.shape[:2] + if height != width: + square_edge = max(height, width) + pad_h = square_edge - height + pad_w = square_edge - width + image_np = cv2.copyMakeBorder( + image_np, + pad_h // 2, + pad_h - pad_h // 2, + pad_w // 2, + pad_w - pad_w // 2, + cv2.BORDER_CONSTANT, + value=(0, 0, 0), + ) + + image_np = cv2.resize(image_np, (shortest_edge, shortest_edge), interpolation=cv2.INTER_AREA) + crop_h = max(1, int(shortest_edge * crop_fraction)) + crop_w = max(1, int(shortest_edge * crop_fraction)) + top = (shortest_edge - crop_h) // 2 + left = (shortest_edge - crop_w) // 2 + image_np = image_np[top : top + crop_h, left : left + crop_w] + return cv2.resize(image_np, (target_size[1], target_size[0]), interpolation=cv2.INTER_AREA) + + +class _DummyGrootModel(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.zeros(())) + self.config = SimpleNamespace(compute_dtype="float32") + self.compute_dtype = "float32" + self.forward_inputs = None + + def forward(self, inputs): + self.forward_inputs = dict(inputs) + return {"loss": self.weight + 1.0} + + def get_action(self, inputs): + self.forward_inputs = dict(inputs) + batch_size = inputs["state"].shape[0] + return {"action_pred": torch.zeros(batch_size, 40, 132, device=self.weight.device)} + + +def test_groot_n1_5_defaults_are_preserved(): + config = GrootConfig(device="cpu") + + assert config.model_version == GROOT_N1_5 + assert config.base_model_path == GROOT_N1_5_BASE_MODEL + assert config.max_state_dim 
== 64 + assert config.max_action_dim == 32 + assert len(config.action_delta_indices) == 16 + + +def test_groot_n1_7_explicit_selection_uses_n1_7_defaults(): + config = GrootConfig(model_version=GROOT_N1_7, device="cpu") + + assert config.model_version == GROOT_N1_7 + assert config.base_model_path == GROOT_N1_7_BASE_MODEL + assert config.max_state_dim == 132 + assert config.max_action_dim == 132 + assert config.chunk_size == 40 + assert config.n_action_steps == 40 + assert len(config.action_delta_indices) == 40 + + +def test_groot_n1_7_accepts_named_action_decode_transform(): + config = GrootConfig( + model_version=GROOT_N1_7, + action_decode_transform="libero", + device="cpu", + ) + + assert config.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_LIBERO + + +@pytest.mark.parametrize("legacy_transform", ["libero_gripper", "libero-gripper"]) +def test_groot_n1_7_rejects_legacy_libero_gripper_action_decode_transform(legacy_transform): + with pytest.raises(ValueError, match="Unsupported GR00T N1.7 action decode transform"): + GrootConfig( + model_version=GROOT_N1_7, + action_decode_transform=legacy_transform, + device="cpu", + ) + + +def test_groot_n1_5_rejects_action_decode_transform(): + with pytest.raises(ValueError, match="action_decode_transform"): + GrootConfig( + model_version=GROOT_N1_5, + action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + device="cpu", + ) + + +def test_groot_n1_7_path_requires_matching_model_version(): + with pytest.raises(ValueError, match="model_version"): + GrootConfig(base_model_path=GROOT_N1_7_BASE_MODEL, device="cpu") + + +def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7(): + with pytest.raises(ValueError, match="does not match base_model_path"): + GrootConfig( + model_version=GROOT_N1_7, + base_model_path=GROOT_N1_5_BASE_MODEL, + device="cpu", + ) + + +def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_gr00t(): + sys.modules.pop("gr00t", None) + + config = 
make_policy_config("groot", model_version=GROOT_N1_7, device="cpu") + + assert isinstance(config, GrootConfig) + assert config.model_version == GROOT_N1_7 + assert "gr00t" not in sys.modules + + +def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path): + model_path = tmp_path / "GR00T-N1.7-local" + model_path.mkdir() + config = _groot_config(GROOT_N1_5) + + with pytest.raises(ValueError, match="does not match base_model_path"): + GrootPolicy.from_pretrained(model_path, config=config) + + +def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatch): + from lerobot.policies.groot.groot_n1_7 import GR00TN17 + + model_path = tmp_path / "GR00T-N1.7-local" + model_path.mkdir() + config = _groot_config(GROOT_N1_7) + + monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: _DummyGrootModel())) + + policy = GrootPolicy.from_pretrained(model_path, config=config) + + assert policy.config.model_version == GROOT_N1_7 + assert policy.config.base_model_path == str(model_path) + + +def test_groot_from_pretrained_infers_n1_7_from_ambiguous_local_config(tmp_path, monkeypatch): + from lerobot.policies.groot.groot_n1_7 import GR00TN17 + + model_path = tmp_path / "local-checkpoint" + model_path.mkdir() + (model_path / "config.json").write_text('{"model_type": "Gr00tN1d7"}') + + monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: _DummyGrootModel())) + + policy = GrootPolicy.from_pretrained(model_path) + + assert policy.config.model_version == GROOT_N1_7 + assert policy.config.base_model_path == str(model_path) + + +def test_raw_n1_7_libero_checkpoint_processors_use_checkpoint_assets(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + + pack_inputs = next(step for step in preprocessor.steps if 
isinstance(step, GrootN17PackInputsStep)) + vlm_encode = next(step for step in preprocessor.steps if isinstance(step, GrootN17VLMEncodeStep)) + decode_actions = next(step for step in postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep)) + + assert pack_inputs.embodiment_tag == "libero_sim" + assert pack_inputs.embodiment_mapping["libero_sim"] == 42 + assert pack_inputs.formalize_language is True + assert pack_inputs.valid_action_horizon == 16 + assert pack_inputs.action_horizon == 40 + assert pack_inputs.max_state_dim == 132 + assert pack_inputs.max_action_dim == 132 + assert pack_inputs.clip_outliers is True + assert pack_inputs.video_modality_keys == ["image", "wrist_image"] + assert pack_inputs.stats[OBS_STATE]["min"] == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] + assert pack_inputs.stats[OBS_STATE]["max"] == [ + 99.0, + 100.0, + 101.0, + 102.0, + 103.0, + 104.0, + 105.0, + 106.0, + ] + assert pack_inputs.stats[ACTION]["min"] == [11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0] + assert vlm_encode.image_crop_size == [230, 230] + assert vlm_encode.image_target_size == [256, 256] + assert vlm_encode.shortest_image_edge == 256 + assert vlm_encode.crop_fraction == 0.95 + assert vlm_encode.use_albumentations is True + assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0] + assert decode_actions.env_action_dim == 7 + assert decode_actions.use_percentiles is True + assert decode_actions.use_relative_action is True + assert decode_actions.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_LIBERO + + +def test_raw_n1_7_checkpoint_requires_percentile_stats_when_config_uses_percentiles(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + statistics = json.loads((model_path / "statistics.json").read_text()) + del statistics["libero_sim"]["state"]["x"]["q01"] + (model_path / "statistics.json").write_text(json.dumps(statistics)) + config = _raw_n1_7_libero_config(model_path) + + with 
pytest.raises(KeyError, match="q01.*state.x"): + make_pre_post_processors(config, pretrained_path=str(model_path)) + + +def test_raw_n1_7_checkpoint_processors_prefer_checkpoint_stats_when_dataset_stats_supplied(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + dataset_stats = { + OBS_STATE: { + "min": torch.full((8,), -8.0), + "max": torch.full((8,), 8.0), + }, + ACTION: { + "min": torch.full((7,), -7.0), + "max": torch.full((7,), 7.0), + }, + } + + preprocessor, postprocessor = make_pre_post_processors( + config, + pretrained_path=str(model_path), + dataset_stats=dataset_stats, + ) + + pack_inputs = next(step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep)) + decode_actions = next(step for step in postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep)) + torch.testing.assert_close( + torch.as_tensor(pack_inputs.stats[OBS_STATE]["min"]), + torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + ) + torch.testing.assert_close( + torch.as_tensor(pack_inputs.stats[ACTION]["max"]), + torch.tensor([109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0]), + ) + assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0] + assert decode_actions.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_LIBERO + + +def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + save_dir = tmp_path / "saved_processors" + + preprocessor.save_pretrained(save_dir) + postprocessor.save_pretrained(save_dir) + + loaded_preprocessor = PolicyProcessorPipeline.from_pretrained( + save_dir, + config_filename="policy_preprocessor.json", + ) + loaded_postprocessor = 
PolicyProcessorPipeline.from_pretrained( + save_dir, + config_filename="policy_postprocessor.json", + ) + pack_inputs = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep)) + decode_actions = next( + step for step in loaded_postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep) + ) + + assert pack_inputs.valid_action_horizon == 16 + assert pack_inputs.action_horizon == 40 + assert pack_inputs.video_modality_keys == ["image", "wrist_image"] + assert pack_inputs.clip_outliers is True + torch.testing.assert_close( + pack_inputs.stats[OBS_STATE]["min"], + torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), + ) + assert decode_actions.env_action_dim == 7 + assert decode_actions.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_LIBERO + assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0] + + +def test_groot_n1_7_pack_inputs_rejects_state_dim_above_core_max(): + step = GrootN17PackInputsStep( + max_state_dim=2, + max_action_dim=4, + normalize_min_max=False, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.zeros(1, 3), + }, + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + with pytest.raises(ValueError, match="State dimension 3 exceeds max_state_dim 2"): + step(transition) + + +def test_groot_n1_7_pack_inputs_rejects_action_shape_above_core_limits(): + step = GrootN17PackInputsStep( + action_horizon=2, + max_state_dim=2, + max_action_dim=2, + normalize_min_max=False, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.zeros(1, 2), + }, + TransitionKey.ACTION: torch.zeros(1, 2, 3), + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + with pytest.raises(ValueError, match="Action dimension 3 exceeds max_action_dim 2"): + step(transition) + + transition[TransitionKey.ACTION] = torch.zeros(1, 3, 2) + with pytest.raises(ValueError, match="Action horizon 3 exceeds action_horizon 2"): + step(transition) + + +def 
test_groot_n1_7_pack_inputs_clips_and_masks_only_valid_action_horizon(): + step = GrootN17PackInputsStep( + action_horizon=40, + valid_action_horizon=16, + max_state_dim=4, + max_action_dim=4, + normalize_min_max=True, + clip_outliers=True, + stats={ + OBS_STATE: {"min": [0.0, 0.0], "max": [1.0, 1.0]}, + ACTION: {"min": [0.0, 0.0], "max": [1.0, 1.0]}, + }, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.tensor([[2.0, -1.0]]), + }, + TransitionKey.ACTION: torch.full((1, 16, 2), 1.0), + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + output = step(transition) + + torch.testing.assert_close( + output[TransitionKey.OBSERVATION]["state"][0, 0, :2], + torch.tensor([1.0, -1.0]), + ) + assert output[TransitionKey.ACTION].shape == (1, 40, 4) + torch.testing.assert_close(output[TransitionKey.ACTION][0, 16:], torch.zeros(24, 4)) + action_mask = output[TransitionKey.COMPLEMENTARY_DATA]["action_mask"] + assert action_mask.shape == (1, 40, 4) + assert action_mask[0, :16, :2].sum().item() == 32 + assert action_mask[0, 16:].sum().item() == 0 + assert action_mask[0, :, 2:].sum().item() == 0 + + +def test_groot_n1_7_pack_inputs_normalizes_state_with_q01_q99_clips_and_pads(): + step = GrootN17PackInputsStep( + action_horizon=4, + max_state_dim=6, + max_action_dim=7, + normalize_min_max=True, + clip_outliers=True, + stats={ + OBS_STATE: { + "min": [0.0, 10.0, -2.0, 4.0], + "max": [10.0, 10.0, 2.0, 8.0], + } + }, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.tensor([[5.0, 42.0, -6.0, 10.0]]), + }, + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + output = step(transition) + + expected = torch.tensor([[[0.0, 0.0, -1.0, 1.0, 0.0, 0.0]]]) + torch.testing.assert_close(output[TransitionKey.OBSERVATION]["state"], expected) + + +def test_groot_n1_7_libero_open_gripper_state_normalizes_near_core_oracle(): + step = GrootN17PackInputsStep( + action_horizon=40, + max_state_dim=132, + max_action_dim=7, + 
normalize_min_max=True, + clip_outliers=True, + stats={ + OBS_STATE: { + "min": [ + -0.27276572585105896, + -0.237214133143425, + 0.916006326675415, + 2.779496669769287, + -1.3187512159347534, + -0.4198998212814331, + 0.001503719249740243, + -0.03989770635962486, + ], + "max": [ + 0.1352936029434204, + 0.362916499376297, + 1.286232590675354, + 3.2829697132110596, + 0.9332759976387024, + 0.6325722336769104, + 0.03993396461009979, + -0.0016719202976673841, + ], + } + }, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.tensor( + [ + [ + -0.20846466720104218, + 0.0, + 1.1732795238494873, + 3.1403393745422363, + 0.0007735038525424898, + -0.0892220064997673, + 0.020833000540733337, + -0.020833000540733337, + ] + ] + ), + }, + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + output = step(transition) + + normalized = output[TransitionKey.OBSERVATION]["state"][0, 0, :8] + expected = torch.tensor( + [ + -0.6848445534706116, + -0.2094583511352539, + 0.3898160457611084, + 0.4334142208099365, + 0.17185509204864502, + -0.3716168999671936, + 0.005941033363342285, + -0.002521216869354248, + ] + ) + torch.testing.assert_close(normalized, expected, atol=1e-6, rtol=1e-6) + assert normalized[6:].abs().max().item() < 0.01 + + +def test_groot_n1_7_pack_inputs_normalizes_action_chunk_per_dimension_before_padding(): + step = GrootN17PackInputsStep( + action_horizon=5, + valid_action_horizon=3, + max_state_dim=4, + max_action_dim=5, + normalize_min_max=True, + clip_outliers=True, + stats={ + OBS_STATE: {"min": [0.0, 0.0], "max": [1.0, 1.0]}, + ACTION: { + "min": [-2.0, 10.0, 100.0], + "max": [2.0, 30.0, 101.0], + }, + }, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.tensor([[0.5, 0.5]]), + }, + TransitionKey.ACTION: torch.tensor( + [ + [ + [-2.0, 30.0, 100.25], + [0.0, 20.0, 101.0], + [2.0, 10.0, 100.0], + ] + ] + ), + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + output = step(transition) + + expected_actions = 
torch.tensor( + [ + [ + [-1.0, 1.0, -0.5, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [1.0, -1.0, -1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ] + ) + torch.testing.assert_close(output[TransitionKey.ACTION], expected_actions) + action_mask = output[TransitionKey.COMPLEMENTARY_DATA]["action_mask"] + assert action_mask.shape == (1, 5, 5) + assert action_mask[0, :3, :3].sum().item() == 9 + assert action_mask[0, 3:].sum().item() == 0 + assert action_mask[0, :, 3:].sum().item() == 0 + + +def test_groot_n1_7_pack_inputs_adds_inference_action_horizon_mask(): + step = GrootN17PackInputsStep( + action_horizon=40, + valid_action_horizon=16, + max_state_dim=8, + max_action_dim=7, + normalize_min_max=False, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.zeros(2, 8), + }, + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move", "Place"]}, + } + + output = step(transition) + + action_mask = output[TransitionKey.COMPLEMENTARY_DATA]["action_mask"] + assert action_mask.shape == (2, 40) + assert action_mask[:, :16].sum().item() == 32 + assert action_mask[:, 16:].sum().item() == 0 + assert output[TransitionKey.COMPLEMENTARY_DATA]["embodiment_id"].dtype == torch.int32 + + +def test_groot_n1_7_pack_inputs_orders_video_by_checkpoint_modality_keys(): + step = GrootN17PackInputsStep( + normalize_min_max=False, + video_modality_keys=["image", "wrist_image"], + ) + transition = { + TransitionKey.OBSERVATION: { + f"{OBS_IMAGES}.zz_extra": torch.full((1, 3, 2, 2), 33, dtype=torch.uint8), + f"{OBS_IMAGES}.image2": torch.full((1, 3, 2, 2), 22, dtype=torch.uint8), + f"{OBS_IMAGES}.image": torch.full((1, 3, 2, 2), 11, dtype=torch.uint8), + OBS_STATE: torch.zeros(1, 8), + }, + TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]}, + } + + output = step(transition) + + video = output[TransitionKey.OBSERVATION]["video"] + assert video.shape == (1, 1, 2, 2, 2, 3) + assert np.unique(video[0, 0, 0]).tolist() == [11] + assert np.unique(video[0, 0, 
1]).tolist() == [22] + assert f"{OBS_IMAGES}.zz_extra" not in output[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGES}.image" not in output[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGES}.image2" not in output[TransitionKey.OBSERVATION] + + +def test_groot_n1_7_postprocessor_clips_normalized_action_before_unnormalizing(): + step = GrootActionUnpackUnnormalizeStep( + env_action_dim=3, + normalize_min_max=True, + clip_normalized_action=True, + stats={ + ACTION: { + "min": [0.0, 0.0, 0.0], + "max": [10.0, 10.0, 10.0], + } + }, + ) + transition = { + TransitionKey.ACTION: torch.tensor([[-2.0, 0.0, 2.0]]), + } + + output = step(transition) + + torch.testing.assert_close(output[TransitionKey.ACTION], torch.tensor([[0.0, 5.0, 10.0]])) + + +def test_groot_n1_7_action_decode_applies_named_libero_transform_from_modality_key(): + unit_stats = { + "min": [0.0], + "max": [1.0], + "mean": [0.5], + "std": [1.0], + "q01": [0.0], + "q99": [1.0], + } + step = GrootN17ActionDecodeStep( + env_action_dim=3, + raw_stats={ + "action": { + "x": unit_stats, + "gripper": unit_stats, + "y": unit_stats, + } + }, + modality_config={ + "action": { + "modality_keys": ["x", "gripper", "y"], + "action_configs": [{}, {}, {}], + } + }, + action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + ) + action = torch.tensor( + [ + [ + [-1.0, -1.0, 1.0], + [1.0, 0.0, -1.0], + [0.0, 1.0, 0.0], + ] + ] + ) + + output = step({TransitionKey.ACTION: action}) + + expected = torch.tensor( + [ + [ + [0.0, 1.0, 1.0], + [1.0, -0.0, 0.0], + [0.5, -1.0, 0.5], + ] + ] + ) + torch.testing.assert_close(output[TransitionKey.ACTION], expected) + + +def test_groot_n1_7_action_decode_requires_gripper_key_for_libero_transform(): + step = GrootN17ActionDecodeStep( + env_action_dim=1, + raw_stats={ + "action": { + "x": { + "min": [0.0], + "max": [1.0], + }, + } + }, + modality_config={ + "action": { + "modality_keys": ["x"], + "action_configs": [{}], + } + }, + 
action_decode_transform=GROOT_ACTION_DECODE_TRANSFORM_LIBERO, + ) + + with pytest.raises(KeyError, match="gripper"): + step({TransitionKey.ACTION: torch.zeros(1, 1, 1)}) + + +def test_groot_n1_7_postprocessor_converts_libero_gripper_convention(): + step = GrootActionUnpackUnnormalizeStep( + env_action_dim=7, + normalize_min_max=True, + stats={ + ACTION: { + "min": [0.0] * 7, + "max": [1.0] * 7, + } + }, + libero_gripper_action=True, + ) + transition = { + TransitionKey.ACTION: torch.tensor( + [ + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ] + ) + } + + output = step(transition) + + torch.testing.assert_close(output[TransitionKey.ACTION][:, -1], torch.tensor([1.0, -1.0])) + + +def test_groot_n1_7_postprocessor_decodes_selected_action_and_gripper_thresholds(): + step = GrootActionUnpackUnnormalizeStep( + env_action_dim=7, + normalize_min_max=True, + clip_normalized_action=True, + stats={ + ACTION: { + "min": [0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 0.0], + "max": [2.0, 14.0, 26.0, 38.0, 50.0, 62.0, 1.0], + } + }, + libero_gripper_action=True, + ) + selected_actions = torch.tensor( + [ + [-1.0, -0.5, 0.0, 0.5, 1.0, 2.0, -0.5], + [-1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 0.0], + [-1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 0.5], + ] + ) + + output = step({TransitionKey.ACTION: selected_actions}) + + expected_prefix = torch.tensor([0.0, 11.0, 23.0, 36.0, 50.0, 62.0]) + torch.testing.assert_close(output[TransitionKey.ACTION][:, :6], expected_prefix.expand(3, 6)) + torch.testing.assert_close(output[TransitionKey.ACTION][:, -1], torch.tensor([1.0, -0.0, -1.0])) + + +def test_groot_n1_7_postprocessor_decodes_action_chunks_without_dropping_timesteps(): + step = GrootActionUnpackUnnormalizeStep( + env_action_dim=7, + normalize_min_max=True, + clip_normalized_action=True, + stats={ + ACTION: { + "min": [0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 0.0], + "max": [2.0, 14.0, 26.0, 38.0, 50.0, 62.0, 1.0], + } + }, + libero_gripper_action=True, + ) + action_chunk = 
torch.tensor( + [ + [ + [-1.0, 0.0, 1.0, -0.5, 0.5, 2.0, -1.0, 99.0], + [0.25, -0.25, 0.75, -0.75, 1.0, -1.0, 0.0, 99.0], + [1.0, -1.0, 0.0, 0.5, -0.5, 0.0, 0.5, 99.0], + ] + ] + ) + + output = step({TransitionKey.ACTION: action_chunk}) + + expected_prefix = torch.tensor( + [ + [ + [0.0, 12.0, 26.0, 32.0, 47.5, 62.0], + [1.25, 11.5, 25.25, 31.0, 50.0, 50.0], + [2.0, 10.0, 23.0, 36.0, 42.5, 56.0], + ] + ] + ) + assert output[TransitionKey.ACTION].shape == (1, 3, 7) + torch.testing.assert_close(output[TransitionKey.ACTION][..., :6], expected_prefix) + torch.testing.assert_close(output[TransitionKey.ACTION][..., -1], torch.tensor([[1.0, -0.0, -1.0]])) + + +def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config(tmp_path): + model_path = tmp_path / "local-checkpoint" + model_path.mkdir() + (model_path / "config.json").write_text('{"model_type": "Gr00tN1d7"}') + config = _groot_config(GROOT_N1_5) + + with pytest.raises(ValueError, match="does not match base_model_path"): + GrootPolicy.from_pretrained(model_path, config=config) + + +def test_groot_n1_7_processors_are_registered_lazily_without_external_gr00t(): + sys.modules.pop("gr00t", None) + config = _groot_config(GROOT_N1_7) + + preprocessor, _ = make_groot_pre_post_processors(config) + step_types = {type(step) for step in preprocessor.steps} + + assert GrootN17PackInputsStep in step_types + assert GrootN17VLMEncodeStep in step_types + assert GrootEagleEncodeStep not in step_types + assert "gr00t" not in sys.modules + + +def test_groot_n1_5_processors_still_use_eagle_path(): + config = _groot_config(GROOT_N1_5) + + preprocessor, _ = make_groot_pre_post_processors(config) + step_types = {type(step) for step in preprocessor.steps} + + assert GrootEagleEncodeStep in step_types + assert GrootN17VLMEncodeStep not in step_types + + +def test_groot_n1_7_pack_inputs_preserves_per_sample_language(): + step = GrootN17PackInputsStep( + action_horizon=2, + max_state_dim=4, + max_action_dim=3, + 
formalize_language=True, + normalize_min_max=False, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + }, + TransitionKey.ACTION: torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]), + TransitionKey.COMPLEMENTARY_DATA: { + "task": ["Pick Red Block!", "Place Blue Cube."], + }, + } + + output = step(transition) + + assert output[TransitionKey.COMPLEMENTARY_DATA]["language"] == [ + "pick red block", + "place blue cube", + ] + torch.testing.assert_close( + output[TransitionKey.OBSERVATION]["state"][:, 0, :2], + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + ) + + +def test_groot_n1_7_language_formalization_preserves_core_task_identifier_and_batch(): + step = GrootN17PackInputsStep( + action_horizon=2, + max_state_dim=8, + max_action_dim=7, + formalize_language=True, + normalize_min_max=False, + ) + transition = { + TransitionKey.OBSERVATION: { + OBS_STATE: torch.zeros(2, 8), + }, + TransitionKey.COMPLEMENTARY_DATA: { + "task": [ + "Pick_Up_The_Black_Bowl_Next_To_The_Ramekin_And_Place_It_On_The_Plate!!!", + "MOVE, the YELLOW mug -- to Zone_2.", + ], + }, + } + + output = step(transition) + + assert output[TransitionKey.COMPLEMENTARY_DATA]["language"] == [ + "pick_up_the_black_bowl_next_to_the_ramekin_and_place_it_on_the_plate", + "move the yellow mug to zone_2", + ] + + +def test_groot_n1_7_vlm_encode_uses_per_sample_language(): + class FakeProcessor: + def __init__(self): + self.rendered_texts = [] + self.encoded_texts = None + + def apply_chat_template(self, conversation, tokenize, add_generation_prompt): + text = conversation[0]["content"][-1]["text"] + self.rendered_texts.append(text) + return f"rendered:{text}" + + def __call__(self, text, images, return_tensors, padding): + self.encoded_texts = text + return { + "input_ids": torch.arange(len(text)).view(len(text), 1), + "attention_mask": torch.ones(len(text), 1, dtype=torch.long), + } + + fake_proc = FakeProcessor() + step = GrootN17VLMEncodeStep() + step._proc = 
fake_proc + transition = { + TransitionKey.OBSERVATION: { + "video": np.zeros((2, 1, 1, 2, 2, 3), dtype=np.uint8), + }, + TransitionKey.COMPLEMENTARY_DATA: { + "language": ["first task", "second task"], + }, + } + + output = step(transition) + + assert fake_proc.rendered_texts == ["first task", "second task"] + assert fake_proc.encoded_texts == ["rendered:first task", "rendered:second task"] + assert "video" not in output[TransitionKey.OBSERVATION] + torch.testing.assert_close( + output[TransitionKey.COMPLEMENTARY_DATA]["input_ids"], + torch.tensor([[0], [1]]), + ) + + +def test_groot_n1_7_vlm_encode_packs_images_time_major_then_camera_order(): + class FakeProcessor: + def __init__(self): + self.add_generation_prompts = [] + self.conversation_image_values = [] + self.conversation_texts = [] + self.encoded_texts = None + self.encoded_image_values = None + + def apply_chat_template(self, conversation, tokenize, add_generation_prompt): + assert tokenize is False + self.add_generation_prompts.append(add_generation_prompt) + content = conversation[0]["content"] + self.conversation_image_values.append( + [int(np.asarray(item["image"])[0, 0, 0]) for item in content if item["type"] == "image"] + ) + text = content[-1]["text"] + self.conversation_texts.append(text) + return f"rendered:{text}" + + def __call__(self, text, images, return_tensors, padding): + assert return_tensors == "pt" + assert padding is True + self.encoded_texts = text + self.encoded_image_values = [int(np.asarray(image)[0, 0, 0]) for image in images] + return { + "input_ids": torch.arange(len(text)).view(len(text), 1), + "attention_mask": torch.ones(len(text), 1, dtype=torch.long), + "pixel_values": torch.arange(len(images)).view(len(images), 1), + "image_grid_thw": torch.ones(len(images), 3, dtype=torch.long), + } + + fake_proc = FakeProcessor() + step = GrootN17VLMEncodeStep() + step._proc = fake_proc + video = np.zeros((2, 2, 2, 2, 2, 3), dtype=np.uint8) + image_id = 1 + for batch_idx in range(2): + 
for timestep in range(2): + for view_idx in range(2): + video[batch_idx, timestep, view_idx, :, :, :] = image_id + image_id += 1 + transition = { + TransitionKey.OBSERVATION: {"video": video}, + TransitionKey.COMPLEMENTARY_DATA: {"language": ["task a", "task b"]}, + } + + output = step(transition) + + assert fake_proc.conversation_image_values == [[1, 2, 3, 4], [5, 6, 7, 8]] + assert fake_proc.encoded_image_values == [1, 2, 3, 4, 5, 6, 7, 8] + assert fake_proc.conversation_texts == ["task a", "task b"] + assert fake_proc.encoded_texts == ["rendered:task a", "rendered:task b"] + assert fake_proc.add_generation_prompts == [False, False] + assert "video" not in output[TransitionKey.OBSERVATION] + assert set(output[TransitionKey.COMPLEMENTARY_DATA]) >= { + "input_ids", + "attention_mask", + "pixel_values", + "image_grid_thw", + } + + +def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path(): + cv2 = pytest.importorskip("cv2", exc_type=ImportError) + from PIL import Image + + image_np = (np.arange(360 * 360 * 3, dtype=np.uint32) % 251).astype(np.uint8).reshape(360, 360, 3) + + transformed = _transform_n1_7_image_for_vlm( + Image.fromarray(image_np), + image_crop_size=[230, 230], + image_target_size=[256, 256], + shortest_image_edge=256, + crop_fraction=0.95, + use_albumentations=True, + ) + + expected = cv2.resize(image_np, (256, 256), interpolation=cv2.INTER_AREA) + crop_edge = int(256 * 0.95) + crop_start = (256 - crop_edge) // 2 + expected = expected[crop_start : crop_start + crop_edge, crop_start : crop_start + crop_edge] + expected = cv2.resize(expected, (256, 256), interpolation=cv2.INTER_AREA) + + assert transformed.size == (256, 256) + np.testing.assert_array_equal(np.asarray(transformed), expected) + + +def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core_albumentations(): + cv2 = pytest.importorskip("cv2", exc_type=ImportError) + + class FakeProcessor: + def __init__(self): + self.images = None + + def 
apply_chat_template(self, conversation, tokenize, add_generation_prompt): + return conversation[0]["content"][-1]["text"] + + def __call__(self, text, images, return_tensors, padding): + self.images = images + return { + "input_ids": torch.ones(len(text), 1, dtype=torch.long), + "attention_mask": torch.ones(len(text), 1, dtype=torch.long), + } + + camera_a = np.arange(3 * 5 * 3, dtype=np.uint8).reshape(3, 5, 3) + camera_b = (np.arange(3 * 5 * 3, dtype=np.uint16).reshape(3, 5, 3) * 3 % 251).astype(np.uint8) + video = np.stack([camera_a, camera_b], axis=0).reshape(1, 1, 2, 3, 5, 3) + fake_proc = FakeProcessor() + step = GrootN17VLMEncodeStep( + image_target_size=[8, 8], + shortest_image_edge=10, + crop_fraction=0.6, + use_albumentations=True, + ) + step._proc = fake_proc + + step( + { + TransitionKey.OBSERVATION: {"video": video}, + TransitionKey.COMPLEMENTARY_DATA: {"language": ["move"]}, + } + ) + + assert fake_proc.images is not None + assert len(fake_proc.images) == 2 + np.testing.assert_array_equal( + np.asarray(fake_proc.images[0]), + _expected_albumentations_eval_image( + camera_a, + cv2, + target_size=[8, 8], + shortest_edge=10, + crop_fraction=0.6, + ), + ) + np.testing.assert_array_equal( + np.asarray(fake_proc.images[1]), + _expected_albumentations_eval_image( + camera_b, + cv2, + target_size=[8, 8], + shortest_edge=10, + crop_fraction=0.6, + ), + ) + + +def test_groot_n1_7_vlm_encode_config_round_trips_model_name(): + step = GrootN17VLMEncodeStep( + model_name="local-cosmos", + image_crop_size=[230, 230], + image_target_size=[256, 256], + shortest_image_edge=256, + crop_fraction=0.95, + use_albumentations=True, + ) + + restored = GrootN17VLMEncodeStep(**step.get_config()) + + assert restored.model_name == "local-cosmos" + assert restored.image_crop_size == [230, 230] + assert restored.image_target_size == [256, 256] + assert restored.shortest_image_edge == 256 + assert restored.crop_fraction == 0.95 + assert restored.use_albumentations is True + + +def 
test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch): + pytest.importorskip("transformers") + + import transformers + + from lerobot.policies.groot import processor_groot + + calls = [] + + class FakeTokenizer: + chat_template = "fake-chat-template" + padding_side = "right" + + @classmethod + def from_pretrained(cls, model_name, **kwargs): + calls.append(("tokenizer", model_name, kwargs)) + return cls() + + class FakeImageProcessor: + @classmethod + def from_pretrained(cls, model_name, **kwargs): + calls.append(("image_processor", model_name, kwargs)) + return cls() + + class FakeVideoProcessor: + @classmethod + def from_pretrained(cls, model_name, **kwargs): + calls.append(("video_processor", model_name, kwargs)) + return cls() + + class FakeProcessor: + from_pretrained_called = False + + def __init__(self, *, image_processor, tokenizer, video_processor, chat_template): + self.image_processor = image_processor + self.tokenizer = tokenizer + self.video_processor = video_processor + self.chat_template = chat_template + + @classmethod + def from_pretrained(cls, *args, **kwargs): + cls.from_pretrained_called = True + raise AssertionError("Cosmos does not publish processor_config.json") + + monkeypatch.setattr(transformers, "AutoTokenizer", FakeTokenizer) + monkeypatch.setattr(transformers, "Qwen2VLImageProcessorFast", FakeImageProcessor) + monkeypatch.setattr(transformers, "Qwen3VLVideoProcessor", FakeVideoProcessor) + monkeypatch.setattr(transformers, "Qwen3VLProcessor", FakeProcessor) + + processor = processor_groot._build_n1_7_processor("nvidia/Cosmos-Reason2-2B") + + assert [call[:2] for call in calls] == [ + ("tokenizer", "nvidia/Cosmos-Reason2-2B"), + ("image_processor", "nvidia/Cosmos-Reason2-2B"), + ("video_processor", "nvidia/Cosmos-Reason2-2B"), + ] + assert all(call[2] == {"trust_remote_code": True} for call in calls) + assert processor.tokenizer.padding_side == "left" + assert processor.chat_template == "fake-chat-template" + assert not 
FakeProcessor.from_pretrained_called
+
+
+def test_groot_n1_7_saved_processors_reload_through_factory(tmp_path):
+    config = _groot_config(GROOT_N1_7)
+    dataset_stats = {
+        OBS_STATE: {
+            "min": torch.zeros(8),
+            "max": torch.ones(8),
+        },
+        ACTION: {
+            "min": torch.zeros(7),
+            "max": torch.ones(7),
+        },
+    }
+    preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=dataset_stats)
+    preprocessor.save_pretrained(tmp_path)
+    postprocessor.save_pretrained(tmp_path)
+
+    loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
+        config,
+        pretrained_path=str(tmp_path),
+        dataset_stats=dataset_stats,
+    )
+
+    pack_step = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
+    unpack_step = loaded_postprocessor.steps[0]
+    assert pack_step.normalize_min_max
+    torch.testing.assert_close(pack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
+    torch.testing.assert_close(pack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
+    torch.testing.assert_close(unpack_step.stats[OBS_STATE]["min"], dataset_stats[OBS_STATE]["min"])
+    torch.testing.assert_close(unpack_step.stats[ACTION]["max"], dataset_stats[ACTION]["max"])
+    assert unpack_step.env_action_dim == 7
+
+
+def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stats(tmp_path):
+    config = _groot_config(GROOT_N1_7)
+    saved_stats = {
+        OBS_STATE: {
+            "min": torch.full((8,), -2.0),
+            "max": torch.full((8,), 2.0),
+        },
+        ACTION: {
+            "min": torch.full((7,), -3.0),
+            "max": torch.full((7,), 3.0),
+        },
+    }
+    preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=saved_stats)
+    preprocessor.save_pretrained(tmp_path)
+    postprocessor.save_pretrained(tmp_path)
+
+    loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
+        config,
+        pretrained_path=str(tmp_path),
+    )
+
+    pack_step = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
+    unpack_step = loaded_postprocessor.steps[0]
+    assert pack_step.normalize_min_max
+    torch.testing.assert_close(pack_step.stats[OBS_STATE]["min"], saved_stats[OBS_STATE]["min"])
+    torch.testing.assert_close(pack_step.stats[ACTION]["max"], saved_stats[ACTION]["max"])
+    torch.testing.assert_close(unpack_step.stats[OBS_STATE]["min"], saved_stats[OBS_STATE]["min"])
+    torch.testing.assert_close(unpack_step.stats[ACTION]["max"], saved_stats[ACTION]["max"])
+    assert unpack_step.env_action_dim == 7
+
+
+def test_groot_policy_selects_n1_7_model_class(monkeypatch):
+    from lerobot.policies.groot.groot_n1_7 import GR00TN17
+
+    called = {}
+
+    def fake_from_pretrained(cls, **kwargs):
+        called.update(kwargs)
+        return _DummyGrootModel()
+
+    monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(fake_from_pretrained))
+
+    policy = GrootPolicy(_groot_config(GROOT_N1_7))
+
+    assert called["pretrained_model_name_or_path"] == GROOT_N1_7_BASE_MODEL
+    assert isinstance(policy._groot_model, _DummyGrootModel)
+
+
+def test_groot_policy_forwards_n1_7_qwen_inputs(monkeypatch):
+    from lerobot.policies.groot.groot_n1_7 import GR00TN17
+
+    dummy_model = _DummyGrootModel()
+    monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model))
+    policy = GrootPolicy(_groot_config(GROOT_N1_7))
+
+    batch = {
+        "state": torch.zeros(2, 1, 132),
+        "action": torch.zeros(2, 40, 132),
+        "action_mask": torch.ones(2, 40, 132),
+        "embodiment_id": torch.zeros(2, dtype=torch.long),
+        "input_ids": torch.ones(2, 8, dtype=torch.long),
+        "attention_mask": torch.ones(2, 8, dtype=torch.long),
+        "pixel_values": torch.zeros(4, 3, 16, 16),
+        "image_grid_thw": torch.ones(4, 3, dtype=torch.long),
+        "mm_token_type_ids": torch.zeros(2, 8, dtype=torch.int32),
+        "pixel_values_videos": torch.zeros(1, 3, 16, 16),
+        "video_grid_thw": torch.ones(1, 3, dtype=torch.long),
+        "next.state": torch.ones(2, 1, 132),
+        "info": {"ignored": True},
+    }
+
+    loss, metrics = policy.forward(batch)
+
+    assert loss.item() == pytest.approx(1.0)
+    assert metrics == {"loss": pytest.approx(1.0)}
+    assert set(dummy_model.forward_inputs) == {
+        "state",
+        "action",
+        "action_mask",
+        "embodiment_id",
+        "input_ids",
+        "attention_mask",
+        "pixel_values",
+        "image_grid_thw",
+        "mm_token_type_ids",
+        "pixel_values_videos",
+        "video_grid_thw",
+    }
+
+
+def test_groot_n1_7_libero_execution_horizon_uses_core_eight_action_cadence(tmp_path):
+    model_path = tmp_path / "libero_spatial"
+    _write_raw_n1_7_libero_checkpoint(model_path)
+
+    assert infer_groot_n1_7_action_horizon(model_path, "libero_sim") == 16
+    assert infer_groot_n1_7_action_execution_horizon(model_path, "libero_sim") == 8
+
+
+def test_groot_n1_7_select_action_uses_checkpoint_valid_horizon(tmp_path, monkeypatch):
+    from lerobot.policies.groot.groot_n1_7 import GR00TN17
+
+    model_path = tmp_path / "libero_spatial"
+    _write_raw_n1_7_libero_checkpoint(model_path)
+
+    class HorizonModel(_DummyGrootModel):
+        def get_action(self, inputs):
+            assert inputs["action_mask"].shape == (1, 40)
+            assert inputs["action_mask"][0, :16].sum().item() == 16
+            assert inputs["action_mask"][0, 16:].sum().item() == 0
+            batch_size = inputs["state"].shape[0]
+            steps = torch.arange(40, dtype=torch.float32).view(1, 40, 1).expand(batch_size, 40, 132)
+            return {"action_pred": steps}
+
+    monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel()))
+    input_features, output_features = _groot_features(state_dim=8, action_dim=7)
+    config = GrootConfig(
+        model_version=GROOT_N1_7,
+        base_model_path=str(model_path),
+        embodiment_tag="libero_sim",
+        input_features=input_features,
+        output_features=output_features,
+        device="cpu",
+        use_bf16=False,
+        n_action_steps=40,
+    )
+    policy = GrootPolicy(config)
+    batch = {
+        "state": torch.zeros(1, 1, 132),
+        "embodiment_id": torch.zeros(1, dtype=torch.long),
+        "input_ids": torch.ones(1, 2, dtype=torch.long),
+        "attention_mask": torch.ones(1, 2, dtype=torch.long),
+        "pixel_values": torch.zeros(1, 3, 2, 2),
+        "image_grid_thw": torch.ones(1, 3, dtype=torch.long),
+        "action_mask": torch.cat((torch.ones(1, 16), torch.zeros(1, 24)), dim=1),
+    }
+
+    first_action = policy.select_action(batch)
+
+    assert policy._action_queue_steps == 8
+    assert len(policy._action_queue) == 7
+    torch.testing.assert_close(first_action[0, 0], torch.tensor(0.0))
+
+    for expected_step in range(1, 8):
+        action = policy.select_action(batch)
+        torch.testing.assert_close(action[0, 0], torch.tensor(float(expected_step)))
+
+    refreshed_action = policy.select_action(batch)
+    torch.testing.assert_close(refreshed_action[0, 0], torch.tensor(0.0))
+
+
+def test_qwen3_backbone_uses_nested_transformers_model_contract(monkeypatch):
+    pytest.importorskip("transformers")
+    from transformers.feature_extraction_utils import BatchFeature
+
+    import lerobot.policies.groot.groot_n1_7 as groot_n1_7
+
+    class FakeLanguageModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layers = nn.ModuleList([nn.Linear(1, 1) for _ in range(3)])
+
+    class FakeVisual(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.proj = nn.Linear(1, 1)
+
+    class FakeInnerModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.language_model = FakeLanguageModel()
+            self.visual = FakeVisual()
+
+    class FakeQwenForConditionalGeneration(nn.Module):
+        config = SimpleNamespace(image_token_id=42)
+
+        def __init__(self):
+            super().__init__()
+            self.model = FakeInnerModel()
+
+        @classmethod
+        def from_pretrained(cls, *args, **kwargs):
+            return cls()
+
+        def eval(self):
+            super().eval()
+            return self
+
+        def forward(self, **kwargs):
+            batch_size, sequence_length = kwargs["input_ids"].shape
+            features = torch.arange(batch_size * sequence_length * 4, dtype=torch.float32).view(
+                batch_size, sequence_length, 4
+            )
+            return SimpleNamespace(hidden_states=[features, features + 1])
+
+    monkeypatch.setattr(
+        groot_n1_7,
+        "Qwen3VLForConditionalGeneration",
+        FakeQwenForConditionalGeneration,
+    )
+    backbone = groot_n1_7.Qwen3Backbone(
+        model_name="fake-qwen",
+        select_layer=2,
+        tune_llm=False,
+        tune_visual=False,
+        use_flash_attention=False,
+    )
+
+    assert not hasattr(backbone.model, "language_model")
+    assert len(backbone.language_model.layers) == 2
+    assert not any(parameter.requires_grad for parameter in backbone.language_model.parameters())
+    assert not any(parameter.requires_grad for parameter in backbone.visual.parameters())
+
+    output = backbone.forward(
+        BatchFeature(
+            data={
+                "input_ids": torch.tensor([[1, 42, 2], [42, 3, 4]]),
+                "attention_mask": torch.tensor([[1, 1, 0], [1, 1, 1]]),
+                "pixel_values": torch.zeros(2, 3, 2, 2),
+                "image_grid_thw": torch.ones(2, 3, dtype=torch.long),
+            }
+        )
+    )
+
+    assert output["backbone_features"].shape == (2, 3, 4)
+    torch.testing.assert_close(
+        output["image_mask"],
+        torch.tensor([[False, True, False], [True, False, False]]),
+    )
+    torch.testing.assert_close(
+        output["backbone_attention_mask"],
+        torch.tensor([[True, True, False], [True, True, True]]),
+    )
+
+
+def test_qwen3_backbone_can_initialize_from_config_without_downloading_weights(monkeypatch):
+    pytest.importorskip("transformers")
+
+    import lerobot.policies.groot.groot_n1_7 as groot_n1_7
+
+    class FakeLanguageModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layers = nn.ModuleList([nn.Linear(1, 1) for _ in range(3)])
+
+    class FakeVisual(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.proj = nn.Linear(1, 1)
+
+    class FakeInnerModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.language_model = FakeLanguageModel()
+            self.visual = FakeVisual()
+
+    class FakeQwenForConditionalGeneration(nn.Module):
+        config = SimpleNamespace(image_token_id=42)
+        from_pretrained_called = False
+        from_config_called = False
+
+        def __init__(self):
+            super().__init__()
+            self.model = FakeInnerModel()
+
+        @classmethod
+        def from_pretrained(cls, *args, **kwargs):
+            cls.from_pretrained_called = True
+            raise AssertionError("Qwen backbone weights should not be loaded separately")
+
+        @classmethod
+        def _from_config(cls, config, **kwargs):
+            cls.from_config_called = True
+            return cls()
+
+        def eval(self):
+            super().eval()
+            return self
+
+    monkeypatch.setattr(groot_n1_7, "Qwen3VLForConditionalGeneration", FakeQwenForConditionalGeneration)
+
+    backbone = groot_n1_7.Qwen3Backbone(
+        model_name="nvidia/Cosmos-Reason2-2B",
+        select_layer=2,
+        load_pretrained_weights=False,
+    )
+
+    assert isinstance(backbone.model, FakeQwenForConditionalGeneration)
+    assert FakeQwenForConditionalGeneration.from_config_called
+    assert not FakeQwenForConditionalGeneration.from_pretrained_called
+
+
+def test_gr00t_n1_7_from_pretrained_defers_backbone_weight_loading(monkeypatch, tmp_path):
+    from huggingface_hub.errors import HFValidationError
+
+    import lerobot.policies.groot.groot_n1_7 as groot_n1_7
+
+    called = {}
+
+    class FakeLoadedModel:
+        def __init__(self):
+            self.config = SimpleNamespace(tune_top_llm_layers=0)
+            self.backbone = SimpleNamespace(set_trainable_parameters=lambda **kwargs: None)
+            self.action_head = SimpleNamespace(set_trainable_parameters=lambda **kwargs: None)
+
+    def fake_snapshot_download(*args, **kwargs):
+        raise HFValidationError("local path")
+
+    def fake_super_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        called["pretrained_model_name_or_path"] = pretrained_model_name_or_path
+        called.update(kwargs)
+        return FakeLoadedModel()
+
+    monkeypatch.setattr(groot_n1_7, "snapshot_download", fake_snapshot_download)
+    monkeypatch.setattr(
+        groot_n1_7.PreTrainedModel,
+        "from_pretrained",
+        classmethod(fake_super_from_pretrained),
+    )
+
+    loaded = groot_n1_7.GR00TN17.from_pretrained(str(tmp_path))
+
+    assert isinstance(loaded, FakeLoadedModel)
+    assert called["pretrained_model_name_or_path"] == str(tmp_path)
+    assert called["load_backbone_weights"] is False
+
+
+def test_gr00t_n1_7_action_head_meta_init_defers_beta_distribution():
+    pytest.importorskip("diffusers")
+
+    from lerobot.policies.groot.groot_n1_7 import GR00TN17ActionHead, GR00TN17Config
+
+    config = GR00TN17Config(
+        backbone_embedding_dim=32,
+        hidden_size=32,
+        input_embedding_dim=32,
+        max_state_dim=7,
+        max_action_dim=5,
+        action_horizon=4,
+        state_history_length=1,
+        max_num_embodiments=4,
+        use_alternate_vl_dit=False,
+        use_vlln=False,
+        add_pos_embed=False,
+        vl_self_attention_cfg={"num_layers": 0},
+        diffusion_model_cfg={
+            "positional_embeddings": None,
+            "num_layers": 1,
+            "num_attention_heads": 2,
+            "attention_head_dim": 16,
+            "norm_type": "ada_norm",
+            "dropout": 0.0,
+            "final_dropout": False,
+            "output_dim": 32,
+            "interleave_self_attention": False,
+        },
+    )
+
+    with torch.device("meta"):
+        meta_action_head = GR00TN17ActionHead(config)
+
+    assert meta_action_head._beta_dist is None
+    assert any(parameter.is_meta for parameter in meta_action_head.parameters())
+
+    action_head = GR00TN17ActionHead(config)
+    sample = action_head.sample_time(batch_size=3, device=torch.device("cpu"), dtype=torch.float32)
+
+    assert action_head._beta_dist is not None
+    assert sample.shape == (3,)
+    assert torch.isfinite(sample).all()
+
+
+def test_gr00t_n1_7_model_forward_with_mocked_backbone():
+    pytest.importorskip("diffusers")
+    pytest.importorskip("transformers")
+
+    from transformers.feature_extraction_utils import BatchFeature
+
+    from lerobot.policies.groot.groot_n1_7 import GR00TN17, GR00TN17Config
+
+    config = GR00TN17Config(
+        backbone_embedding_dim=32,
+        hidden_size=32,
+        input_embedding_dim=32,
+        max_state_dim=7,
+        max_action_dim=5,
+        action_horizon=4,
+        state_history_length=1,
+        num_inference_timesteps=2,
+        max_num_embodiments=4,
+        use_alternate_vl_dit=False,
+        use_vlln=True,
+        vl_self_attention_cfg={"num_layers": 0},
+        state_dropout_prob=0.0,
+        diffusion_model_cfg={
+            "positional_embeddings": None,
+            "num_layers": 1,
+            "num_attention_heads": 2,
+            "attention_head_dim": 16,
+            "norm_type": "ada_norm",
+            "dropout": 0.0,
+            "final_dropout": False,
+            "output_dim": 32,
+            "interleave_self_attention": False,
+        },
+    )
+
+    class MockBackbone(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = nn.Parameter(torch.zeros(()))
+
+        def prepare_input(self, inputs):
+            return BatchFeature(data=inputs)
+
+        def forward(self, inputs):
+            batch_size = inputs["state"].shape[0]
+            return BatchFeature(
+                data={
+                    "backbone_features": torch.randn(batch_size, 3, config.backbone_embedding_dim),
+                    "backbone_attention_mask": torch.ones(batch_size, 3, dtype=torch.bool),
+                    "image_mask": torch.zeros(batch_size, 3, dtype=torch.bool),
+                }
+            )
+
+        def set_trainable_parameters(self, *args, **kwargs):
+            return None
+
+    with patch(
+        "lerobot.policies.groot.groot_n1_7.get_backbone_cls",
+        return_value=lambda **kwargs: MockBackbone(),
+    ):
+        model = GR00TN17(config)
+
+    inputs = {
+        "state": torch.randn(2, config.state_history_length, config.max_state_dim),
+        "action": torch.randn(2, config.action_horizon, config.max_action_dim),
+        "action_mask": torch.ones(2, config.action_horizon, config.max_action_dim),
+        "embodiment_id": torch.zeros(2, dtype=torch.long),
+    }
+
+    output = model.forward(inputs)
+    assert output["loss"].dim() == 0
+    assert torch.isfinite(output["loss"])
+
+    inference_inputs = {key: value for key, value in inputs.items() if key != "action"}
+    action_output = model.get_action(inference_inputs)
+    assert action_output["action_pred"].shape == (2, config.action_horizon, config.max_action_dim)