mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-13 14:39:44 +00:00
Compare commits
2 Commits
| SHA1 |
|---|
| f81fc79564 |
| 9ce6633518 |
@@ -4,6 +4,9 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason

 LeRobot integrates GR00T N1.7 through the `groot` policy type.

+> [!WARNING]
+> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
+
 ## Model Overview

 GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
@@ -133,7 +136,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
 Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.

```bash
-huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
+hf download nvidia/GR00T-N1.7-LIBERO \
   --include "libero_spatial/*" \
   --local-dir ./GR00T-N1.7-LIBERO

@@ -1,6 +1,13 @@
 ## Research Paper

-Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
+GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
+
+GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
+
+GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
+
+> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
+> Current releases support GR00T N1.7 only.

 ## Repository

@@ -31,12 +38,22 @@ Hugging Face Models:

 ## Original-vs-LeRobot parity test

-`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
+`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
 reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
-produces the **same raw model output** (`get_action(...)["action_pred"]`, the
-normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
-byte-identical pre-processed inputs and the same flow-matching seed. It is
-parametrized over every embodiment tag present in the checkpoint.
+against NVIDIA's original `gr00t` package with two comparisons, each parametrized
+over every embodiment tag present in the checkpoint:
+
+1. **Model parity** — given byte-identical pre-processed inputs and the same
+   flow-matching seed (recorded in each artifact), both implementations must produce
+   the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
+   flow-matching prediction). Output shapes must match exactly; any action-horizon
+   or action-dim mismatch fails the test.
+2. **Preprocessor parity** — given the identical raw observations (per-camera
+   frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
+   (real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
+   state normalization, no mocks) must produce the **same collated model inputs**
+   (`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
+   `embodiment_id`) as the original package's processor.

 ### Why two environments

@@ -48,25 +65,37 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo

 So the test uses a **producer / consumer** split across two venvs:

-1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
+1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
    gr00t venv. For each embodiment it builds dummy inputs generically from the
    checkpoint metadata (state dims from `statistics.json`; camera/language keys from
-   the processor modality configs), runs the original model, and saves the exact
-   collated inputs + raw `action_pred` to one `.npz` per tag.
-2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
-   `.npz`, replays the byte-identical inputs through the LeRobot model with the same
-   seed, and asserts the outputs match.
+   the processor modality configs), runs the original model, and saves to one `.npz`
+   per tag: the raw observations (`raw::` keys), the exact collated inputs
+   (`in::` keys), the seed, and the raw `action_pred`.
+2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
+   `.npz`; the model-parity case replays the byte-identical collated inputs through
+   the LeRobot model with the recorded seed and asserts the outputs match, and the
+   preprocessor-parity case replays the raw observations through LeRobot's full
+   preprocessor pipeline and asserts the collated tensors match.
+
+> Artifacts generated by older versions of the dump script contain no `raw::`
+> fields; the preprocessor-parity case then **skips** with a regeneration hint.
+> Re-run the producer to refresh them.

 ### Fairness controls

-- **Same pre-processed inputs** — the original processor's `input_ids`,
+- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
   `pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
-  fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
+  fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
+  model comparison isolates the model. LeRobot's own tokenization / image packing is
+  covered separately by the preprocessor-parity case, which compares its output
+  against those same collated tensors from identical raw observations.
 - **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
   original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
   producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
   kernel/rounding noise, not an implementation difference.)
-- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
+- **Same flow-matching seed** — fixed right before sampling on both sides; the
+  producer records it in each artifact (`--seed`, default 42) and the consumer
+  replays the recorded value.

 ### How to run

@@ -90,15 +119,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
   uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
```

-The `.npz` artifacts are local-only (gitignored, ~6–9 MB each) and are regenerated by
-the producer; they are never committed. The test **skips** (does not fail) on CI or
+The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
+the producer; they are never committed. The tests **skip** (do not fail) on CI or
 when the checkpoint / artifacts are absent.

 #### Env knobs (all optional)

-| Var | Default | Purpose |
-|---|---|---|
-| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
-| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
-| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
-| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
+| Var                                       | Default                          | Purpose                               |
+| ----------------------------------------- | -------------------------------- | ------------------------------------- |
+| `GROOT_N1_7_PARITY_DIR`                   | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
+| `GROOT_N1_7_LIBERO_CKPT`                  | auto (HF cache)                  | override checkpoint dir               |
+| `GROOT_PARITY_DEVICE`                     | `cuda` if available              | `cpu` or `cuda`                       |
+| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3`                           | comparison tolerance                  |

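The tolerance knobs in the table above can be illustrated with a small pure-Python sketch. It assumes the combined criterion `|actual - expected| <= atol + rtol * |expected|` (the convention used by NumPy-style closeness checks); the diff does not show which comparison function the test actually calls, and the helper names `read_tolerance` and `within_tolerance` are hypothetical.

```python
import os

DEFAULT_TOLERANCE = 1e-3  # default listed for GROOT_PARITY_ATOL / GROOT_PARITY_RTOL


def read_tolerance(name, env=None):
    """Return the float stored in variable `name`, or the documented default when unset."""
    env = os.environ if env is None else env
    raw = env.get(name, "")
    return float(raw) if raw.strip() else DEFAULT_TOLERANCE


def within_tolerance(actual, expected, atol, rtol):
    """Check |a - e| <= atol + rtol * |e| elementwise, after an exact length check."""
    if len(actual) != len(expected):
        # Mirrors the rule that any shape mismatch fails the comparison outright.
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


atol = read_tolerance("GROOT_PARITY_ATOL")
rtol = read_tolerance("GROOT_PARITY_RTOL")
print(within_tolerance([0.5, -1.25], [0.5004, -1.2495], atol, rtol))
```

The length check comes first so that a horizon or dimension mismatch is reported as a failure rather than being hidden by truncation in `zip`.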
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -42,6 +43,9 @@ else:
|
||||
Timesteps = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
require_package("diffusers", extra="groot")
|
||||
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
         self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
         self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
-        print(
-            "Total number of DiT parameters: ",
+        logger.debug(
+            "Total number of DiT parameters: %d",
             sum(p.numel() for p in self.parameters() if p.requires_grad),
         )

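The switch from `print` to `logger.debug` with a `%d` placeholder can be illustrated with the standard `logging` module alone. The sketch below is not the DiT code: `ListHandler`, `report_parameter_count`, and the logger name are hypothetical, and the parameter sizes are made up.

```python
import logging


class ListHandler(logging.Handler):
    """Collects formatted messages so the effect of the log level is visible."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def report_parameter_count(logger, name, sizes):
    # %-style arguments are only merged into the message if the record is emitted.
    logger.debug("Total number of %s parameters: %d", name, sum(sizes))


demo_logger = logging.getLogger("groot_logging_demo")
demo_logger.propagate = False  # keep the demo isolated from the root logger
handler = ListHandler()
demo_logger.addHandler(handler)

demo_logger.setLevel(logging.WARNING)
report_parameter_count(demo_logger, "DiT", [10, 20])  # suppressed at WARNING

demo_logger.setLevel(logging.DEBUG)
report_parameter_count(demo_logger, "DiT", [10, 20])  # emitted at DEBUG

print(handler.messages)
```

Note that only the string formatting is deferred. The `sum(...)` argument is still evaluated on every call, so the parameter count is computed even when debug output is disabled.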
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
                 for _ in range(self.config.num_layers)
             ]
         )
-        print(
-            "Total number of SelfAttentionTransformer parameters: ",
+        logger.debug(
+            "Total number of SelfAttentionTransformer parameters: %d",
             sum(p.numel() for p in self.parameters() if p.requires_grad),
         )

@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -23,15 +24,29 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||
GROOT_N1_5 = "n1.5"
|
||||
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
"GR00T N1.5 support was removed from LeRobot. "
|
||||
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||
|
||||
_GROOT_MODEL_VERSION_ALIASES = {
|
||||
"n1.7": GROOT_N1_7,
|
||||
@@ -41,7 +56,12 @@ _GROOT_MODEL_VERSION_ALIASES = {
|
||||
"1.7": GROOT_N1_7,
|
||||
}
|
||||
|
||||
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||
|
||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||
"none": None,
|
||||
"": None,
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
@@ -52,9 +72,10 @@ def normalize_groot_model_version(model_version: str) -> str:
     normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
     if normalized is None:
         supported = GROOT_N1_7
-        raise ValueError(
-            f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
-        )
+        message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
+        if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
+            message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
+        raise ValueError(message)
     return normalized


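The behaviour of this hunk can be tried in isolation with the sketch below. It is a stand-in under stated assumptions: the alias table contains only the two N1.7 spellings visible in the diff (the real table has more entries that are not shown), the guidance text is shortened, and all names are local to the example.

```python
SUPPORTED_VERSION = "n1.7"

# Only the spellings visible in the diff; the real alias table is longer.
VERSION_ALIASES = {"n1.7": SUPPORTED_VERSION, "1.7": SUPPORTED_VERSION}

# Legacy spellings are detected only so the error can carry migration guidance.
N1_5_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}

# Shortened stand-in for GROOT_N1_5_REMOVAL_GUIDANCE.
REMOVAL_GUIDANCE = (
    "GR00T N1.5 support was removed from LeRobot. "
    "Pin lerobot==0.5.1 or migrate to GR00T N1.7."
)


def normalize_version(model_version: str) -> str:
    """Map a user-supplied spelling to the canonical version or raise ValueError."""
    key = model_version.lower()
    normalized = VERSION_ALIASES.get(key)
    if normalized is None:
        message = (
            f"Unsupported GR00T model_version '{model_version}'. "
            f"Supported versions: {SUPPORTED_VERSION}."
        )
        if key in N1_5_ALIASES:
            # Same exception type, but with the extra hint appended.
            message = f"{message} {REMOVAL_GUIDANCE}"
        raise ValueError(message)
    return normalized


def error_for(version: str) -> str:
    """Return the error message produced for `version`, or '' if it is accepted."""
    try:
        normalize_version(version)
    except ValueError as exc:
        return str(exc)
    return ""


print(normalize_version("N1.7"))
print(error_for("n1.5"))
```

Keeping the legacy spellings in a separate set, rather than in the alias table, means a removed version can never be silently mapped to a supported one.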
@@ -286,6 +307,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
model_version = config.get("model_version")
|
||||
if isinstance(model_version, str):
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
return GROOT_N1_5
|
||||
try:
|
||||
return normalize_groot_model_version(model_version)
|
||||
except ValueError:
|
||||
@@ -298,8 +321,17 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
|
||||
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
|
||||
# instead of being silently treated as N1.7 (see infer_groot_model_version).
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
return GROOT_N1_7
|
||||
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||
backbone_cfg = config.get("backbone_cfg")
|
||||
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||
return GROOT_N1_5
|
||||
return None
|
||||
|
||||
|
||||
@@ -310,27 +342,32 @@ class GrootConfig(PreTrainedConfig):

     # Basic policy settings
     n_obs_steps: int = 1
-    chunk_size: int = 50
-    n_action_steps: int = 50
+    chunk_size: int = 40
+    n_action_steps: int = 40

     # Dimension settings (must match pretrained GR00T model expectations)
     # Maximum state dimension. Shorter states will be zero-padded.
-    max_state_dim: int = 64
+    max_state_dim: int = 132

     # Maximum action dimension. Shorter actions will be zero-padded.
-    max_action_dim: int = 32
+    max_action_dim: int = 132

-    # Normalization (start with identity, adjust as needed)
+    # GR00T normalizes state/action internally in its processor steps (min/max with
+    # q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
+    # handles image normalization. The policy therefore does NOT use LeRobot's
+    # NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
+    # IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.IDENTITY,
-            "STATE": NormalizationMode.MEAN_STD,
-            "ACTION": NormalizationMode.MEAN_STD,
+            "STATE": NormalizationMode.IDENTITY,
+            "ACTION": NormalizationMode.IDENTITY,
         }
     )

-    # Image preprocessing (adjust to match Groot's expected input)
-    image_size: tuple[int, int] = (224, 224)
+    # Deprecated and unused: image sizing is handled by the backbone's image processor.
+    # Kept only so config.json files saved with earlier versions still parse.
+    image_size: tuple[int, int] = (256, 256)

     # Groot-specific model parameters (from groot_finetune_script.py)

@@ -344,7 +381,14 @@ class GrootConfig(PreTrainedConfig):
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
action_decode_transform: str | None = None
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
|
||||
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
|
||||
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
|
||||
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
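The `'auto'` sentinel described in the comments above can be sketched as follows. The alias table mirrors the entries shown in the diff; the function name, the whitespace/case handling, and the treatment of a `None` input are assumptions, since `normalize_groot_action_decode_transform` itself is not shown.

```python
AUTO = "auto"
LIBERO = "libero"

# Mirrors the alias entries visible in the diff.
TRANSFORM_ALIASES = {AUTO: AUTO, "none": None, "": None, LIBERO: LIBERO}


def resolve_action_decode_transform(value, embodiment_tag):
    """Normalize the user value, then expand 'auto' to the embodiment default."""
    if value is not None:
        key = value.strip().lower()  # assumption: spellings are case-insensitive
        if key not in TRANSFORM_ALIASES:
            raise ValueError(f"Unsupported action_decode_transform '{value}'.")
        value = TRANSFORM_ALIASES[key]
    if value == AUTO:
        # Only the sentinel picks up the embodiment default.
        return LIBERO if embodiment_tag == "libero_sim" else None
    # An explicit opt-out (None) or an explicit transform is returned unchanged.
    return value


print(resolve_action_decode_transform("auto", "libero_sim"))
print(resolve_action_decode_transform("none", "libero_sim"))
```

The point of the separate sentinel is that a resolved `None` saved to disk and loaded again stays `None`: the embodiment default is applied only when the value is still `'auto'`.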
@@ -384,17 +428,13 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
@@ -405,6 +445,15 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
|
||||
# serialized into every groot checkpoint config.json, so a value here means a legacy
|
||||
# N1.5 checkpoint or config is being loaded.
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
@@ -416,26 +465,48 @@ class GrootConfig(PreTrainedConfig):
         # 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
         # This matches the embodiment-specific handling already done for the
         # action execution horizon (see infer_groot_n1_7_action_execution_horizon).
-        if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
-            self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
+        # Only the 'auto' sentinel resolves to the embodiment default; an explicit
+        # 'none' (normalized to None above) keeps the transform disabled.
+        if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
+            self.action_decode_transform = (
+                GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
+            )

-        if self.max_state_dim == 64:
-            self.max_state_dim = 132
-        if self.max_action_dim == 32:
-            self.max_action_dim = 132
-        if self.chunk_size == 50:
-            self.chunk_size = 40
-        if self.n_action_steps == 50:
-            self.n_action_steps = 40
-        if tuple(self.image_size) == (224, 224):
-            self.image_size = (256, 256)
+        # GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
+        # stale configs) are migrated to the values the N1.7 checkpoints expect, with a
+        # warning. The dataclass defaults are already the N1.7 values, so a plain
+        # GrootConfig() never triggers this.
+        legacy_default_remaps = (
+            ("max_state_dim", 64, 132),
+            ("max_action_dim", 32, 132),
+            ("chunk_size", 50, 40),
+            ("n_action_steps", 50, 40),
+            ("image_size", (224, 224), (256, 256)),
+        )
+        for field_name, legacy_value, n1_7_value in legacy_default_remaps:
+            current_value = getattr(self, field_name)
+            if isinstance(legacy_value, tuple):
+                current_value = tuple(current_value)
+            if current_value == legacy_value:
+                logger.warning(
+                    "GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
+                    "the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
+                    "if this is not what you want.",
+                    field_name,
+                    legacy_value,
+                    n1_7_value,
+                )
+                setattr(self, field_name, n1_7_value)

         inferred_version = infer_groot_model_version(self.base_model_path)
         if inferred_version is not None and inferred_version != self.model_version:
-            raise ValueError(
+            message = (
                 f"GR00T model_version '{self.model_version}' does not match base_model_path "
                 f"'{self.base_model_path}', which looks like '{inferred_version}'."
             )
+            if inferred_version == GROOT_N1_5:
+                message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
+            raise ValueError(message)

         super().__post_init__()

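The table-driven remap in this hunk can be exercised on its own with a small dataclass. `DemoConfig` and the logger name are hypothetical stand-ins; only the field names and the legacy/new value pairs are taken from the diff.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger("groot_remap_demo")

# (field name, legacy N1.5-era default, value expected by N1.7), as listed in the diff.
LEGACY_DEFAULT_REMAPS = (
    ("max_state_dim", 64, 132),
    ("max_action_dim", 32, 132),
    ("chunk_size", 50, 40),
    ("n_action_steps", 50, 40),
    ("image_size", (224, 224), (256, 256)),
)


@dataclass
class DemoConfig:
    """Hypothetical stand-in for the relevant GrootConfig fields."""

    chunk_size: int = 40
    n_action_steps: int = 40
    max_state_dim: int = 132
    max_action_dim: int = 132
    image_size: tuple = (256, 256)

    def __post_init__(self):
        for field_name, legacy_value, new_value in LEGACY_DEFAULT_REMAPS:
            current = getattr(self, field_name)
            if isinstance(legacy_value, tuple):
                # JSON round-trips turn tuples into lists, so compare as tuples.
                current = tuple(current)
            if current == legacy_value:
                logger.warning(
                    "%s=%s matches a legacy default; remapping it to %s.",
                    field_name,
                    legacy_value,
                    new_value,
                )
                setattr(self, field_name, new_value)


print(DemoConfig(chunk_size=50, image_size=[224, 224]))
```

One consequence of this design is that a value equal to a legacy default cannot be told apart from a deliberate choice, so it is always remapped. The warning is what makes that visible to the user.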
@@ -511,9 +582,22 @@ class GrootConfig(PreTrainedConfig):

     @property
     def action_delta_indices(self) -> list[int]:
-        """Return indices for delta actions."""
-        model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
-        return list(range(min(self.chunk_size, model_action_horizon)))
+        """Return indices for delta actions.
+
+        The model action horizon is read from the checkpoint's processor_config.json
+        when available; the result is cached (keyed on the inputs that determine it) so
+        repeated access during dataset/training setup does not re-read from disk.
+        """
+        cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
+        cached = getattr(self, "_action_delta_indices_cache", None)
+        if cached is not None and cached[0] == cache_key:
+            return cached[1]
+        model_action_horizon = (
+            infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
+        )
+        indices = list(range(min(self.chunk_size, model_action_horizon)))
+        object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
+        return indices

     @property
     def reward_delta_indices(self) -> None:
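The input-keyed cache used by `action_delta_indices` can be shown with a self-contained class. Everything here is illustrative: `DemoConfig` is hypothetical, `_infer_horizon` stands in for `infer_groot_n1_7_action_horizon`, and the horizon of 16 for `libero_sim` is a made-up value, not one read from a checkpoint.

```python
class DemoConfig:
    """Hypothetical config with a property cached on the inputs that determine it."""

    def __init__(self, base_model_path, embodiment_tag, chunk_size):
        self.base_model_path = base_model_path
        self.embodiment_tag = embodiment_tag
        self.chunk_size = chunk_size
        self.lookups = 0  # counts simulated disk reads
        self._cache = None

    def _infer_horizon(self):
        # Stand-in for reading processor_config.json; 16 is an invented example value.
        self.lookups += 1
        return 16 if self.embodiment_tag == "libero_sim" else None

    @property
    def action_delta_indices(self):
        cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
        if self._cache is not None and self._cache[0] == cache_key:
            return self._cache[1]
        horizon = self._infer_horizon() or 40  # fall back to 40 when unknown
        indices = list(range(min(self.chunk_size, horizon)))
        self._cache = (cache_key, indices)
        return indices


cfg = DemoConfig("some/checkpoint", "libero_sim", chunk_size=40)
first = cfg.action_delta_indices
second = cfg.action_delta_indices  # served from the cache
cfg.chunk_size = 8                 # changing an input invalidates the cache
third = cfg.action_delta_indices
print(len(first), len(third), cfg.lookups)
```

Because the cache stores and returns the same list object, a caller that mutates the returned list would also change the cached value. Returning a copy would avoid that at the cost of an allocation per access.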
@@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 12,
|
||||
"select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"load_bf16": False,
|
||||
@@ -819,11 +819,16 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:


 def get_backbone_cls(config: GR00TN17Config):
-    if (
-        config.backbone_model_type == "qwen"
-        or "nvidia/Cosmos-Reason2" in config.model_name
-        or "Qwen/Qwen3-VL" in config.model_name
-    ):
+    if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
         return Qwen3Backbone
+    if config.backbone_model_type == "qwen":
+        # Local backbone checkpoints (e.g. hub-cache snapshot paths) contain neither hub
+        # marker, so trust the explicit backbone type but surface what is being assumed.
+        logger.warning(
+            "Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
+            "backbone because backbone_model_type='qwen'.",
+            config.model_name,
+        )
+        return Qwen3Backbone
     raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")

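The two-stage dispatch in `get_backbone_cls` (recognized hub name first, explicit backbone type as a logged fallback) can be sketched without the model classes. The function below returns a string instead of the `Qwen3Backbone` class, and its name, the logger name, and the example paths are hypothetical.

```python
import logging

logger = logging.getLogger("groot_backbone_demo")

# Hub name fragments checked by the diff.
HUB_MARKERS = ("nvidia/Cosmos-Reason2", "Qwen/Qwen3-VL")


def select_backbone(model_name: str, backbone_model_type: str) -> str:
    """Return the backbone label for a config, warning when the name is unrecognized."""
    if any(marker in model_name for marker in HUB_MARKERS):
        return "Qwen3Backbone"
    if backbone_model_type == "qwen":
        # Local snapshot paths carry no hub marker, so fall back to the explicit type.
        logger.warning(
            "Unrecognized backbone model name '%s'; assuming a Qwen3-VL-compatible backbone.",
            model_name,
        )
        return "Qwen3Backbone"
    raise ValueError(f"Unsupported GR00T N1.7 backbone model: {model_name}")


print(select_backbone("nvidia/Cosmos-Reason2-2B", "qwen"))
print(select_backbone("/local/snapshots/abc123", "qwen"))
```

Both branches return the same backbone. The split only changes whether a warning is emitted, which makes an assumed backbone visible in the logs instead of silent.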
@@ -909,7 +914,11 @@ class GR00TN17(PreTrainedModel):
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
for key in ("revision", "cache_dir", "local_files_only", "token"):
|
||||
# Only repo-agnostic hub kwargs are forwarded to the backbone loading kwargs:
|
||||
# ``revision`` pins the GR00T checkpoint repo (see snapshot_download below) and would
|
||||
# be invalid for the unrelated backbone repo (``config.model_name``). Pin the backbone
|
||||
# itself by passing ``revision`` inside ``transformers_loading_kwargs``.
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
|
||||
|
||||
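The effect of dropping `revision` from the forwarded keys can be sketched with plain dictionaries. The helper name and the example values are hypothetical; the key tuple and the `setdefault` behaviour follow the hunk above.

```python
# Hub kwargs that make sense for any repo; `revision` is deliberately not included
# because it pins the GR00T checkpoint repo, not the backbone repo.
FORWARDED_KEYS = ("cache_dir", "local_files_only", "token")


def build_backbone_loading_kwargs(kwargs, transformers_loading_kwargs=None):
    """Merge repo-agnostic hub kwargs into the backbone loading kwargs."""
    if transformers_loading_kwargs is None:
        loading = {"trust_remote_code": True}
    else:
        loading = dict(transformers_loading_kwargs)  # copy so the caller's dict is untouched
    for key in FORWARDED_KEYS:
        if key in kwargs:
            # setdefault keeps any value the caller already put in the backbone kwargs.
            loading.setdefault(key, kwargs[key])
    return loading


checkpoint_kwargs = {"revision": "checkpoint-rev", "cache_dir": "/tmp/hf", "token": None}
print(build_backbone_loading_kwargs(checkpoint_kwargs))
print(build_backbone_loading_kwargs(checkpoint_kwargs, {"revision": "backbone-rev"}))
```

In the second call the backbone is pinned through its own `revision` entry, which is the route the comment in the diff recommends.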
@@ -18,15 +18,12 @@
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||
possible without porting their code.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
possible without porting their code. Dataset loading and training
|
||||
orchestration are handled by LeRobot's standard training stack.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_groot import (
|
||||
GROOT_N1_5,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7,
|
||||
GrootConfig,
|
||||
infer_groot_model_version,
|
||||
@@ -50,6 +49,8 @@ from .configuration_groot import (
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
|
||||
@@ -92,8 +93,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
# GR00TN17 defines no compute_dtype attribute, so only record the
|
||||
# bf16 preference when it is enabled instead of reading a default back.
|
||||
if self.config.use_bf16:
|
||||
model.compute_dtype = "bfloat16"
|
||||
model.config.compute_dtype = "bfloat16"
|
||||
|
||||
return model
|
||||
|
||||
@@ -148,9 +152,10 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
print(
|
||||
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
pretrained_name_or_path,
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
@@ -181,7 +186,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
@@ -197,7 +202,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
@@ -229,10 +234,13 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
raise ValueError(
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
@@ -297,9 +305,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
allowed_base.add("action_mask")
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
def _prepare_n1_7_rtc_inputs(
|
||||
@@ -320,9 +326,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if prev_actions.ndim == 2:
|
||||
prev_actions = prev_actions.unsqueeze(0)
|
||||
elif prev_actions.ndim != 3:
|
||||
raise ValueError(
|
||||
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
|
||||
)
|
||||
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||
|
||||
state = inputs.get("state")
|
||||
if state is None:
|
||||
@@ -331,9 +335,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||
elif prev_actions.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
|
||||
)
|
||||
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||
|
||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||
@@ -346,7 +348,9 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
return inputs, None
|
||||
|
||||
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
|
||||
model_action_horizon = int(
|
||||
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||
)
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||
@@ -409,6 +413,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
if loss is None:
|
||||
raise RuntimeError(
|
||||
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||
"'action' and 'action_mask'; check the preprocessor output."
|
||||
)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
@@ -471,33 +480,21 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
|
||||
@@ -15,14 +15,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
@@ -59,6 +61,7 @@ from lerobot.utils.constants import (
|
||||
|
||||
from .configuration_groot import (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7_BACKBONE_MODEL,
|
||||
GrootConfig,
|
||||
is_raw_groot_n1_7_checkpoint,
|
||||
@@ -80,12 +83,6 @@ N1_7_EMBODIMENT_MAPPING = {
|
||||
"new_embodiment": 10,
|
||||
}
|
||||
|
||||
_N1_7_RAW_STATE_CACHE: dict[str, dict[str, np.ndarray]] = {}
|
||||
|
||||
|
||||
def _n1_7_state_cache_key(value: str | None) -> str:
|
||||
return value or "groot_n1_7_default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GrootN17CheckpointProcessorAssets:
|
||||
@@ -471,25 +468,98 @@ def _legacy_groot_processor_overrides(
|
||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}))
|
||||
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
|
||||
action_unpack_overrides["normalize_min_max"] = True
|
||||
action_unpack_overrides["env_action_dim"] = env_action_dim
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
|
||||
|
||||
return preprocessor_overrides, postprocessor_overrides
|
||||
|
||||
|
||||
def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||
"""Check whether a serialized processor pipeline contains a registry step.
|
||||
|
||||
Resolves the processor config from a local directory or, for Hub repo ids,
|
||||
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
|
||||
False when the config cannot be resolved; loading then proceeds with the
|
||||
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
|
||||
without them if they do not match the serialized pipeline.
|
||||
"""
|
||||
path = Path(pretrained_path).expanduser()
|
||||
if not path.is_dir():
|
||||
if path.is_dir():
|
||||
config = _read_json(path / config_filename)
|
||||
elif path.exists():
|
||||
return False
|
||||
config = _read_json(path / config_filename)
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
config = _read_json(Path(config_path))
|
||||
steps = config.get("steps", [])
|
||||
if not isinstance(steps, list):
|
||||
return False
|
||||
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
|
||||
|
||||
|
||||
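The step lookup performed by `_pretrained_processor_config_has_step` can be shown on its own. This is a minimal sketch, not the LeRobot code: `config_has_step` is a hypothetical helper, the JSON layout (a top-level `steps` list of dicts carrying `registry_name`) is taken from the function above, and the file is written to a temporary directory only to keep the example self-contained.

```python
import json
import tempfile
from pathlib import Path


def config_has_step(config_path: Path, step_name: str) -> bool:
    """Return True when the serialized pipeline lists a step with this registry name."""
    config = json.loads(config_path.read_text())
    steps = config.get("steps", [])
    if not isinstance(steps, list):
        # A malformed "steps" entry is treated as "step not present".
        return False
    return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "policy_postprocessor.json"  # hypothetical filename
    path.write_text(json.dumps({"steps": [{"registry_name": "groot_n1_7_action_decode_v1"}]}))
    has_decode = config_has_step(path, "groot_n1_7_action_decode_v1")
    has_legacy = config_has_step(path, "groot_action_unpack_unnormalize_v2")
    print(has_decode, has_legacy)
```

Returning `False` for anything unexpected keeps the check advisory: the caller can still fall back to loading with the legacy overrides and retry without them.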
def _apply_groot_step_overrides(
    pipeline: PolicyProcessorPipeline,
    overrides: dict[str, Any] | None,
) -> None:
    """Apply ``from_pretrained``-style step overrides to a freshly built pipeline.

    Raw N1.7 checkpoints build their processors from scratch instead of
    deserializing them, so caller overrides must be applied to the constructed
    steps. Override keys match a step's registry name or, as a convenience, its
    class name (``PolicyProcessorPipeline.from_pretrained`` matches registered
    steps by registry name only — prefer registry names so overrides keep
    working after the checkpoint is converted and reloaded from a serialized
    pipeline). Keys or fields that match nothing raise instead of being dropped
    silently.
    """

    if not overrides:
        return

    def _step_keys(step: ProcessorStep) -> set[str]:
        keys = {type(step).__name__}
        registry_name = getattr(type(step), "_registry_name", None)
        if registry_name:
            keys.add(registry_name)
        return keys

    for override_key, step_overrides in overrides.items():
        matched_steps = [step for step in pipeline.steps if override_key in _step_keys(step)]
        if not matched_steps:
            available = [
                getattr(type(step), "_registry_name", None) or type(step).__name__ for step in pipeline.steps
            ]
            raise KeyError(
                f"Override key '{override_key}' does not match any step of the GR00T processor pipeline "
                f"built for this raw N1.7 checkpoint. Available step keys: {available}."
            )
        for step in matched_steps:
            if not is_dataclass(step):
                raise TypeError(
                    f"Cannot apply overrides to step '{override_key}': it is not a dataclass step."
                )
            init_field_names = {f.name for f in fields(step) if f.init}
            for field_name, value in dict(step_overrides).items():
                if field_name not in init_field_names:
                    raise TypeError(
                        f"Override field '{field_name}' is not a config field of step '{override_key}'. "
                        f"Available fields: {sorted(init_field_names)}."
                    )
                setattr(step, field_name, value)
            # Re-derive attributes computed from the overridden config (e.g.
            # DeviceProcessorStep resolves its torch.device in __post_init__).
            post_init = getattr(step, "__post_init__", None)
            if callable(post_init):
                post_init()
|
||||
|
||||
|
||||
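The override mechanics of `_apply_groot_step_overrides` can be tried in isolation. The sketch below is a simplified stand-in rather than the LeRobot implementation: `DeviceStep` and `apply_overrides` are hypothetical names. It shows only the pattern of validating override names against a dataclass's init fields, setting them, and re-running `__post_init__` so that derived attributes stay consistent.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any


@dataclass
class DeviceStep:
    """Hypothetical step: `resolved` is derived from `device` in __post_init__."""

    device: str = "cpu"
    resolved: str = field(default="", init=False)

    def __post_init__(self) -> None:
        self.resolved = f"device:{self.device}"


def apply_overrides(step: Any, overrides: dict[str, Any]) -> None:
    """Set init fields on a dataclass step, rejecting unknown field names."""
    if not is_dataclass(step):
        raise TypeError("step is not a dataclass instance")
    init_field_names = {f.name for f in fields(step) if f.init}
    for name, value in overrides.items():
        if name not in init_field_names:
            raise TypeError(f"unknown override field {name!r}; available: {sorted(init_field_names)}")
        setattr(step, name, value)
    # Recompute attributes that were derived from the overridden fields.
    post_init = getattr(step, "__post_init__", None)
    if callable(post_init):
        post_init()


step = DeviceStep()
apply_overrides(step, {"device": "cuda"})
print(step.resolved)
```

Restricting overrides to `init=True` fields means derived or private attributes (here `resolved`) cannot be overwritten directly, which mirrors the `f.init` filter in the function above.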
def make_groot_pre_post_processors_from_pretrained(
|
||||
config: GrootConfig,
|
||||
pretrained_path: str,
|
||||
@@ -508,12 +578,20 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
||||
processor_cfg = copy(config)
|
||||
processor_cfg.base_model_path = str(pretrained_path)
|
||||
return make_groot_pre_post_processors(
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config=processor_cfg,
|
||||
dataset_stats=dataset_stats,
|
||||
)
|
||||
# Raw checkpoints have no serialized pipelines to load overrides into,
|
||||
# so apply the caller overrides (e.g. device and rename_map from
|
||||
# lerobot-eval or the policy server) to the freshly built steps.
|
||||
_apply_groot_step_overrides(preprocessor, preprocessor_overrides)
|
||||
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
if _local_processor_config_has_step(
|
||||
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
if _pretrained_processor_config_has_step(
|
||||
pretrained_path,
|
||||
postprocessor_config_filename,
|
||||
"groot_n1_7_action_decode_v1",
|
||||
@@ -521,15 +599,55 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
|
||||
# action decoder. Adding the legacy action-unpack override would target
|
||||
# a step that is not present and break loading.
|
||||
preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
applied_legacy_overrides = False
|
||||
preprocessor_overrides = caller_preprocessor_overrides
|
||||
postprocessor_overrides = caller_postprocessor_overrides
|
||||
else:
|
||||
applied_legacy_overrides = True
|
||||
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
||||
config=config,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
)
|
||||
try:
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
except KeyError:
|
||||
if not applied_legacy_overrides:
|
||||
raise
|
||||
# The legacy overrides target steps that are absent from the serialized
|
||||
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
|
||||
# config could not be inspected before loading); retry with the caller
|
||||
# overrides only.
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=caller_preprocessor_overrides,
|
||||
postprocessor_overrides=caller_postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
|
||||
def _load_groot_processor_pipelines(
|
||||
pretrained_path: str,
|
||||
*,
|
||||
preprocessor_overrides: dict[str, Any],
|
||||
postprocessor_overrides: dict[str, Any],
|
||||
preprocessor_config_filename: str,
|
||||
postprocessor_config_filename: str,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=preprocessor_config_filename,
|
||||
@@ -544,7 +662,6 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
|
||||
@@ -564,6 +681,28 @@ def _reconnect_groot_relative_absolute_steps(
|
||||
step.relative_step = relative_step
|
||||
|
||||
|
||||
def _reconnect_groot_n1_7_pack_decode_steps(
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
) -> None:
    """Re-link a deserialized N1.7 action decode step to its pack step.

    The pack step holds the per-instance raw-state cache that relative-action
    decoding reads its reference state from; the link itself is not serialized.
    """

    pack_step = next(
        (step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep)),
        None,
    )
    if pack_step is None:
        return

    for step in postprocessor.steps:
        if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
            step.pack_step = pack_step
|
||||
|
||||
|
||||
def make_groot_pre_post_processors(
|
||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[
|
||||
@@ -607,9 +746,12 @@ def make_groot_pre_post_processors(
|
||||
else action_horizon
|
||||
)
|
||||
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
|
||||
padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {})
|
||||
checkpoint_has_stats = _has_modality_stats(checkpoint_stats)
|
||||
padded_stats = checkpoint_stats if checkpoint_has_stats else (dataset_stats or {})
|
||||
embodiment_mapping = (
|
||||
checkpoint_assets.embodiment_mapping if checkpoint_assets is not None else dict(N1_7_EMBODIMENT_MAPPING)
|
||||
checkpoint_assets.embodiment_mapping
|
||||
if checkpoint_assets is not None
|
||||
else dict(N1_7_EMBODIMENT_MAPPING)
|
||||
)
|
||||
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
|
||||
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
|
||||
@@ -618,7 +760,6 @@ def make_groot_pre_post_processors(
|
||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
state_cache_key = f"groot_n1_7:{config.embodiment_tag}"
|
||||
pack_step = GrootN17PackInputsStep(
|
||||
state_horizon=1,
|
||||
action_horizon=action_horizon,
|
||||
@@ -636,7 +777,6 @@ def make_groot_pre_post_processors(
|
||||
video_modality_keys=video_modality_keys,
|
||||
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
|
||||
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
||||
state_cache_key=state_cache_key,
|
||||
)
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
@@ -647,14 +787,31 @@ def make_groot_pre_post_processors(
|
||||
model_name=config.n1_7_backbone_model,
|
||||
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
|
||||
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
|
||||
shortest_image_edge=checkpoint_assets.shortest_image_edge if checkpoint_assets is not None else None,
|
||||
shortest_image_edge=checkpoint_assets.shortest_image_edge
|
||||
if checkpoint_assets is not None
|
||||
else None,
|
||||
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
|
||||
use_albumentations=checkpoint_assets.use_albumentations if checkpoint_assets is not None else False,
|
||||
use_albumentations=checkpoint_assets.use_albumentations
|
||||
if checkpoint_assets is not None
|
||||
else False,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
if checkpoint_assets is None:
|
||||
if checkpoint_assets is not None and not checkpoint_has_stats and not _has_modality_stats(padded_stats):
|
||||
raise ValueError(
|
||||
f"GR00T N1.7 checkpoint '{config.base_model_path}' has no statistics for embodiment tag "
|
||||
f"'{config.embodiment_tag}', and no dataset stats were provided to fall back to, so "
|
||||
"actions cannot be normalized or decoded. Pass dataset_stats, or set "
|
||||
"config.embodiment_tag to an embodiment present in the checkpoint's statistics.json."
|
||||
)
|
||||
if checkpoint_assets is None or not checkpoint_has_stats:
|
||||
# When the checkpoint sidecars have no stats for the configured
|
||||
# embodiment tag (e.g. finetuning a raw base checkpoint with the
|
||||
# default 'new_embodiment' tag), the pack step above normalized with
|
||||
# the dataset stats; the decode step must invert with the same stats
|
||||
# instead of using a checkpoint decoder whose empty stats would
|
||||
# silently return normalized [-1, 1] actions.
|
||||
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
|
||||
env_action_dim=env_action_dim,
|
||||
stats=padded_stats,
|
||||
@@ -669,7 +826,6 @@ def make_groot_pre_post_processors(
|
||||
use_percentiles=checkpoint_assets.use_percentiles,
|
||||
use_relative_action=checkpoint_assets.use_relative_action,
|
||||
pack_step=pack_step,
|
||||
state_cache_key=state_cache_key,
|
||||
action_decode_transform=config.action_decode_transform,
|
||||
)
|
||||
|
||||
@@ -770,7 +926,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
||||
try:
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
Qwen2VLImageProcessorFast,
|
||||
Qwen2VLImageProcessor,
|
||||
Qwen3VLProcessor,
|
||||
Qwen3VLVideoProcessor,
|
||||
)
|
||||
@@ -781,7 +937,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
||||
) from exc
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
image_processor = Qwen2VLImageProcessorFast.from_pretrained(model_name, trust_remote_code=True)
|
||||
image_processor = Qwen2VLImageProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
proc = Qwen3VLProcessor(
|
||||
image_processor=image_processor,
|
||||
@@ -902,8 +1058,11 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
video_modality_keys: list[str] | None = None
|
||||
raw_stats: dict[str, Any] | None = None
|
||||
modality_config: dict[str, Any] | None = None
|
||||
# Unused: kept so serialized configs that include it still load. The raw
|
||||
# state cache is per instance (_last_raw_state), never process-global.
|
||||
state_cache_key: str = ""
|
||||
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
||||
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
|
||||
available = {key for key in obs if key.startswith(OBS_IMAGES)}
|
||||
@@ -913,19 +1072,56 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
return sorted(available)
|
||||
|
||||
ordered: list[str] = []
|
||||
unmatched: list[str] = []
|
||||
for modality_key in self.video_modality_keys:
|
||||
candidates = [f"{OBS_IMAGES}.{modality_key}"]
|
||||
# Alias for datasets converted with generic camera names (e.g. the
|
||||
# LIBERO conversions expose the wrist camera as
|
||||
# `observation.images.image2`), so raw N1.7 LIBERO checkpoints
|
||||
# match those datasets out of the box.
|
||||
if modality_key == "wrist_image":
|
||||
candidates.append(f"{OBS_IMAGES}.image2")
|
||||
elif modality_key == "image":
|
||||
candidates.append(f"{OBS_IMAGES}.image")
|
||||
|
||||
match = next((candidate for candidate in candidates if candidate in available), None)
|
||||
if match is not None:
|
||||
if match is None:
|
||||
unmatched.append(modality_key)
|
||||
else:
|
||||
ordered.append(match)
|
||||
|
||||
if not ordered:
|
||||
if not self._warned_image_keys:
|
||||
self._warned_image_keys = True
|
||||
logging.warning(
|
||||
"None of the GR00T N1.7 checkpoint video modality keys %s match a camera among %s; "
|
||||
"falling back to feeding all cameras in alphabetical order, which is unlikely to be "
|
||||
"the layout the checkpoint was trained with. Rename the dataset cameras (e.g. via "
|
||||
"--rename_map) to match the checkpoint keys.",
|
||||
self.video_modality_keys,
|
||||
sorted(available),
|
||||
)
|
||||
return sorted(available)
|
||||
unused = sorted(available - set(ordered))
|
||||
if (unmatched or unused) and not self._warned_image_keys:
|
||||
self._warned_image_keys = True
|
||||
if unmatched:
|
||||
logging.warning(
|
||||
"GR00T N1.7 checkpoint video modality keys %s have no matching camera among %s; "
|
||||
"the model will receive %d view(s) instead of the %d it was trained with. Rename "
|
||||
"the dataset cameras (e.g. via --rename_map) to match the checkpoint keys %s.",
|
||||
unmatched,
|
||||
sorted(available),
|
||||
len(ordered),
|
||||
len(self.video_modality_keys),
|
||||
self.video_modality_keys,
|
||||
)
|
||||
if unused:
|
||||
logging.warning(
|
||||
"Dropping camera(s) %s: the GR00T N1.7 checkpoint only consumes the video modality "
|
||||
"keys %s, which matched %s.",
|
||||
unused,
|
||||
self.video_modality_keys,
|
||||
ordered,
|
||||
)
|
||||
return ordered
|
||||
|
||||
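The camera matching in `_ordered_image_keys` can be sketched outside the class. This is a reduced illustration and makes several assumptions: `order_image_keys` and `ALIASES` are hypothetical names, the `observation.images` prefix is inferred from the alias comment above, and the warnings are replaced by returned lists.

```python
OBS_IMAGES = "observation.images"  # assumed prefix, inferred from the alias comment above

# Hypothetical alias table mirroring the two special cases in the method.
ALIASES = {"wrist_image": ["image2"], "image": ["image"]}


def order_image_keys(
    available: set[str], modality_keys: list[str]
) -> tuple[list[str], list[str], list[str]]:
    """Return (ordered camera keys, unmatched modality keys, unused camera keys)."""
    ordered: list[str] = []
    unmatched: list[str] = []
    for modality_key in modality_keys:
        candidates = [f"{OBS_IMAGES}.{modality_key}"]
        candidates += [f"{OBS_IMAGES}.{alias}" for alias in ALIASES.get(modality_key, [])]
        match = next((c for c in candidates if c in available), None)
        if match is None:
            unmatched.append(modality_key)
        else:
            ordered.append(match)
    unused = sorted(available - set(ordered))
    return ordered, unmatched, unused


cameras = {f"{OBS_IMAGES}.image", f"{OBS_IMAGES}.image2", f"{OBS_IMAGES}.top"}
ordered, unmatched, unused = order_image_keys(cameras, ["image", "wrist_image", "front"])
print(ordered, unmatched, unused)
```

Tracking `unmatched` and `unused` separately is what lets the method emit two distinct warnings: one for views the checkpoint expects but does not receive, and one for cameras that are dropped.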
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -988,7 +1184,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
start_idx += dim
|
||||
if grouped:
|
||||
self._last_raw_state = grouped
|
||||
_N1_7_RAW_STATE_CACHE[_n1_7_state_cache_key(self.state_cache_key)] = grouped
|
||||
|
||||
img_keys = self._ordered_image_keys(obs)
|
||||
if img_keys:
|
||||
@@ -1101,7 +1296,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
"video_modality_keys": self.video_modality_keys,
|
||||
"raw_stats": self.raw_stats,
|
||||
"modality_config": self.modality_config,
|
||||
"state_cache_key": self.state_cache_key,
|
||||
}
|
||||
|
||||
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
|
||||
@@ -1223,6 +1417,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
"use_albumentations": self.use_albumentations,
|
||||
}
|
||||
|
||||
|
||||
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
|
||||
for stat_name in ("mean", "q01", "min", "max", "std"):
|
||||
value = entry.get(stat_name)
|
||||
@@ -1351,6 +1546,18 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
group with the checkpoint stats, converts relative groups to absolute values
|
||||
using the raw state cached during packing, concatenates groups in checkpoint
|
||||
order, and finally slices to the environment action dimension.
|
||||
|
||||
Relative-action decoding reads the reference state from the connected
|
||||
``pack_step`` (re-linked after ``from_pretrained`` by
|
||||
``_reconnect_groot_n1_7_pack_decode_steps``), i.e. the state seen by the
|
||||
most recent preprocess call. Engines that decode the whole chunk right
|
||||
after prediction (RTC, async policy server) therefore use the
|
||||
prediction-time state, matching Isaac-GR00T. The sync per-step queue path
|
||||
instead decodes each popped (B, D) action against the latest observation:
|
||||
the reference can be newer than the observation the chunk was predicted
|
||||
from, and per-timestep relative stats are applied as if the popped action
|
||||
were chunk step 0. Fixing that would require carrying the reference state
|
||||
and chunk index alongside each queued action through the postprocessor.
|
||||
"""
|
||||
|
||||
env_action_dim: int = 0
|
||||
@@ -1358,6 +1565,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
modality_config: dict[str, Any] | None = None
|
||||
use_percentiles: bool = False
|
||||
use_relative_action: bool = False
|
||||
# Unused: kept so serialized configs that include it still load.
|
||||
state_cache_key: str = ""
|
||||
action_decode_transform: str | None = None
|
||||
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
||||
@@ -1378,6 +1586,12 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
return transition
|
||||
|
||||
action_np = action.detach().cpu().float().numpy()
|
||||
# The sync action queue postprocesses popped actions as (B, D); decode
|
||||
# them as single-step (B, 1, D) chunks and squeeze the horizon back at
|
||||
# the end so both ranks share the chunk decode logic below.
|
||||
squeeze_horizon = action_np.ndim == 2
|
||||
if squeeze_horizon:
|
||||
action_np = action_np[:, None, :]
|
||||
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
|
||||
if valid_horizon is not None:
|
||||
action_np = action_np[:, :valid_horizon]
|
||||
@@ -1405,17 +1619,24 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
use_relative_action=self.use_relative_action,
|
||||
use_percentiles=self.use_percentiles,
|
||||
)
|
||||
# Per-timestep stats carry one row per chunk step; align them with
|
||||
# the decoded horizon (chunks always start at step 0, and a popped
|
||||
# (B, D) action is decoded as step 0).
|
||||
if min_v.ndim == 2 and normalized.shape[1] <= min_v.shape[0]:
|
||||
min_v = min_v[: normalized.shape[1]]
|
||||
max_v = max_v[: normalized.shape[1]]
|
||||
decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v)
|
||||
start_idx += dim
|
||||
|
||||
if self.use_relative_action:
|
||||
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
|
||||
if raw_state is None:
|
||||
raw_state = _N1_7_RAW_STATE_CACHE.get(_n1_7_state_cache_key(self.state_cache_key))
|
||||
if raw_state is None:
|
||||
raise RuntimeError(
|
||||
"GrootN17ActionDecodeStep requires cached raw state from GrootN17PackInputsStep "
|
||||
"to convert relative N1.7 actions back to absolute actions."
|
||||
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
|
||||
"GrootN17PackInputsStep to convert relative N1.7 actions back to absolute actions. "
|
||||
"Build both pipelines through make_groot_pre_post_processors (or load them together "
|
||||
"via make_groot_pre_post_processors_from_pretrained) and run the preprocessor on an "
|
||||
"observation before decoding actions."
|
||||
)
|
||||
for idx, key in enumerate(action_keys):
|
||||
if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs):
|
||||
@@ -1451,6 +1672,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
action_keys=action_keys,
|
||||
decoded_groups=decoded_groups,
|
||||
)
|
||||
if squeeze_horizon:
|
||||
decoded = decoded[:, 0]
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.ACTION] = torch.as_tensor(
|
||||
decoded, dtype=action.dtype, device=action.device
|
||||
@@ -1467,13 +1690,15 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
"modality_config": self.modality_config,
|
||||
"use_percentiles": self.use_percentiles,
|
||||
"use_relative_action": self.use_relative_action,
|
||||
"state_cache_key": self.state_cache_key,
|
||||
"action_decode_transform": self.action_decode_transform,
|
||||
}
|
||||
|
||||
|
||||
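The rank handling added to `GrootN17ActionDecodeStep` above can be sketched with plain NumPy. This is only an illustration: `decode` and `unnormalize_min_max` are hypothetical helpers, and mapping `[-1, 1]` back to `[min, max]` is an assumed convention, not necessarily what `_unnormalize_min_max` does. It shows a popped `(B, D)` action decoded as a one-step chunk against the step-0 row of the per-timestep stats.

```python
import numpy as np


def unnormalize_min_max(x: np.ndarray, min_v: np.ndarray, max_v: np.ndarray) -> np.ndarray:
    """Assumed convention: map values in [-1, 1] back to [min_v, max_v]."""
    return (x + 1.0) / 2.0 * (max_v - min_v) + min_v


def decode(action: np.ndarray, min_v: np.ndarray, max_v: np.ndarray) -> np.ndarray:
    """Decode a (B, T, D) chunk or a single (B, D) action with per-timestep stats."""
    squeeze_horizon = action.ndim == 2
    if squeeze_horizon:
        # Treat a popped (B, D) action as a one-step chunk (B, 1, D).
        action = action[:, None, :]
    # Per-timestep stats have one row per chunk step; keep the rows for the decoded horizon.
    if min_v.ndim == 2 and action.shape[1] <= min_v.shape[0]:
        min_v = min_v[: action.shape[1]]
        max_v = max_v[: action.shape[1]]
    decoded = unnormalize_min_max(action, min_v, max_v)
    return decoded[:, 0] if squeeze_horizon else decoded


min_v = np.array([[0.0, 0.0], [10.0, 10.0]])  # one row per chunk step
max_v = np.array([[2.0, 4.0], [20.0, 30.0]])
single = decode(np.zeros((3, 2)), min_v, max_v)  # decoded with the step-0 stats row
chunk = decode(np.zeros((3, 2, 2)), min_v, max_v)
print(single.shape, chunk.shape)
```

As the class docstring notes, decoding a popped action as step 0 is an approximation whenever the action actually came from a later position in the chunk.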
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
|
||||
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
|
||||
# action chunks to the last timestep, so old serialized v1 pipelines must not
|
||||
# silently load into it (v1 is stubbed below with the removal guidance).
|
||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
|
||||
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
env_action_dim: int = 0
|
||||
# Apply inverse of min-max normalization if it was used in preprocessor
|
||||
@@ -1585,3 +1810,37 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
|
||||
if reconstructed:
|
||||
self.stats = reconstructed
|
||||
|
||||
|
||||
def _register_removed_n1_5_step_stub(registry_name: str) -> None:
    """Register a stub for a processor step that only GR00T N1.5 pipelines serialize.

    Saved N1.5 checkpoints reference these registry names in their processor JSON
    files. Deserializing them must fail with the canonical N1.5 removal guidance
    instead of an opaque registry KeyError (or, for
    ``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose
    action-chunk semantics changed).
    """

    @ProcessorStepRegistry.register(name=registry_name)
    class _RemovedGrootN15ProcessorStep(ProcessorStep):
        def __init__(self, **_kwargs: Any) -> None:
            raise ValueError(
                f"Processor step '{registry_name}' belongs to a GR00T N1.5 processor pipeline. "
                f"{GROOT_N1_5_REMOVAL_GUIDANCE}"
            )

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            raise NotImplementedError

        def transform_features(self, features):
            raise NotImplementedError


for _removed_n1_5_step_name in (
    "groot_pack_inputs_v3",
    "groot_eagle_encode_v3",
    "groot_eagle_collate_v3",
    "groot_action_unpack_unnormalize_v1",
):
    _register_removed_n1_5_step_stub(_removed_n1_5_step_name)
|
||||
|
||||
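The tombstone pattern used by `_register_removed_n1_5_step_stub` can be reproduced with a toy registry. Everything in this sketch is hypothetical (`REGISTRY`, `register`, `register_removed_stub`, the step names, and the guidance text); it is not the `ProcessorStepRegistry` API. The idea is that each removed name stays registered, but instantiating it fails with actionable guidance.

```python
from typing import Any, Callable

# Hypothetical stand-in for a step registry; not the LeRobot ProcessorStepRegistry.
REGISTRY: dict[str, type] = {}

REMOVAL_GUIDANCE = "Pin an older release or migrate to the supported version."  # placeholder text


def register(name: str) -> Callable[[type], type]:
    """Class decorator that stores the class under `name`."""

    def decorator(cls: type) -> type:
        REGISTRY[name] = cls
        return cls

    return decorator


def register_removed_stub(name: str) -> None:
    """Register a class under `name` whose constructor always raises with guidance."""

    @register(name)
    class _RemovedStep:
        def __init__(self, **_kwargs: Any) -> None:
            raise ValueError(f"Processor step '{name}' was removed. {REMOVAL_GUIDANCE}")


for removed_name in ("old_pack_v3", "old_unpack_v1"):
    register_removed_stub(removed_name)

try:
    REGISTRY["old_unpack_v1"](stats={})
except ValueError as err:
    print(err)
```

Defining the class inside a function gives each stub its own closure over `name`, so the error message names the exact step that the serialized pipeline referenced.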
@@ -207,6 +207,11 @@ def test_lerobot_groot_forward_pass():
|
||||
with torch.no_grad():
|
||||
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
|
||||
|
||||
assert isinstance(lerobot_loss, torch.Tensor)
|
||||
assert torch.isfinite(lerobot_loss).all()
|
||||
assert "loss" in lerobot_metrics
|
||||
assert np.isfinite(lerobot_metrics["loss"])
|
||||
|
||||
print("\nForward pass successful.")
|
||||
print(f" - Loss: {lerobot_loss.item():.6f}")
|
||||
print(f" - Metrics: {lerobot_metrics}")
|
||||
|
||||
File diff suppressed because it is too large
@@ -14,31 +14,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
+"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.

-Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
-head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
-normalized flow-matching prediction before any action decoding) as NVIDIA's original
-``gr00t`` package, given byte-identical pre-processed inputs and the same
-flow-matching seed. The comparison is parametrized over every embodiment tag present
-in the checkpoint.
+Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
+once in the original ``gr00t`` env by the companion script
+``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):

-To keep the comparison fair, the original outputs + the exact collated inputs are
-produced once per embodiment in the original ``gr00t`` env via the companion script
-``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
-to per-tag ``.npz`` files.
-This test discovers those artifacts, replays the identical inputs through the LeRobot
-model, and compares.
+1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
+   action head + Qwen3-VL backbone must produce the SAME raw model output
+   (``action_pred``, the normalized flow-matching prediction before any action
+   decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
+   pre-processed inputs and the flow-matching seed recorded in the artifact.
+2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
+   template / tokenizer / image packing + state normalization, no mocks) must produce
+   the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
+   as the original package's processor, given the identical raw observations
+   (images, state, language) recorded in the artifact. Artifacts written by older
+   versions of the dump script carry no raw observations; this case then SKIPS with
+   a regeneration hint.

-This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
-present, or when no artifact has been generated. By default it looks for artifacts in
+These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
+present, or when no artifact has been generated. By default they look for artifacts in
 ``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
 "Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
 for the full run procedure.
 """

 import os
+import warnings
 from pathlib import Path
+from typing import Any

 import numpy as np
 import pytest
@@ -50,7 +55,9 @@ pytestmark = pytest.mark.skipif(
 )

 from lerobot.policies.groot.configuration_groot import GROOT_N1_7  # noqa: E402,F401
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE  # noqa: E402

+# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
 SEED = 42
 DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
 ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
@@ -60,6 +67,11 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
 _ARTIFACT_PREFIX = "original_n1_7_"
 _ARTIFACT_SUFFIX = ".npz"

+# Collated keys compared by the preprocessor parity case: integer/id tensors must
+# match exactly; float tensors within ATOL/RTOL.
+_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
+_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
+

 def _artifact_dir() -> Path:
     """Directory holding the per-embodiment .npz artifacts.
@@ -109,9 +121,20 @@ def _resolve_checkpoint() -> str:
     return str(ckpt)


-def _load_artifact(path: Path):
+def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
+    """Return (original action_pred, collated model inputs, flow-matching seed)."""
     data = np.load(path, allow_pickle=True)
     original_action = torch.from_numpy(data["action_pred"]).float()
+    if "seed" in data.files:
+        seed = int(data["seed"])
+    else:
+        warnings.warn(
+            f"Artifact '{path.name}' does not record the producer seed (it predates the current "
+            f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
+            "regenerate the artifact with the current dump script.",
+            stacklevel=2,
+        )
+        seed = SEED
     dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
     inputs = {}
     for key in data.files:
@@ -124,7 +147,45 @@ def _load_artifact(path: Path):
         if "int" in declared or "long" in declared:
             t = t.long()
         inputs[name] = t
-    return original_action, inputs
+    return original_action, inputs, seed
+
+
+def _load_raw_observation(path: Path) -> dict[str, Any] | None:
+    """Return the raw observation recorded in the artifact, or None for old artifacts.
+
+    Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
+    exact raw observation the producer fed to the original processor: per-camera uint8
+    frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
+    (``raw::state.<key>``, (B, T, dim)) and the language instruction
+    (``raw::language``, one string per batch element). ``raw_video_keys`` /
+    ``raw_state_keys`` record the checkpoint modality-key order.
+    """
+    data = np.load(path, allow_pickle=True)
+    markers = ("raw_video_keys", "raw_state_keys", "raw::language")
+    if any(marker not in data.files for marker in markers):
+        return None
+    video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
+    state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
+    return {
+        "video": {k: data[f"raw::video.{k}"] for k in video_keys},
+        "state": {k: data[f"raw::state.{k}"] for k in state_keys},
+        "language": [str(t) for t in data["raw::language"].tolist()],
+    }
+
+
+def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
+    """Convert the producer's raw observation into a LeRobot policy batch."""
+    batch: dict[str, Any] = {}
+    for key, frames in raw["video"].items():
+        # (B, T, H, W, C) uint8 -> (B, T, C, H, W); the pack step converts back losslessly.
+        batch[f"{OBS_IMAGES}.{key}"] = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).contiguous()
+    # observation.state is the per-key state vectors (latest frame) concatenated in
+    # checkpoint modality-key order -- the layout the LeRobot pack step and the
+    # flattened checkpoint statistics expect.
+    state_parts = [torch.from_numpy(np.asarray(arr)[:, -1, :]).float() for arr in raw["state"].values()]
+    batch[OBS_STATE] = torch.cat(state_parts, dim=-1)
+    batch["task"] = list(raw["language"])
+    return batch


 def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
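The channel-last to channel-first conversion in `_raw_observation_to_lerobot_batch` is a pure axis permutation, so converting back recovers the original frames bit for bit. A minimal NumPy sketch of that round trip and of the latest-frame state concatenation follows. The helper names, array shapes, and the `arm` / `gripper` state keys are illustrative assumptions; the test itself uses `torch.Tensor.permute`.

```python
import numpy as np


def to_channel_first(frames: np.ndarray) -> np.ndarray:
    """(B, T, H, W, C) -> (B, T, C, H, W), copied into contiguous memory."""
    return np.ascontiguousarray(frames.transpose(0, 1, 4, 2, 3))


def to_channel_last(frames: np.ndarray) -> np.ndarray:
    """(B, T, C, H, W) -> (B, T, H, W, C), the inverse permutation."""
    return np.ascontiguousarray(frames.transpose(0, 1, 3, 4, 2))


rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(2, 1, 4, 6, 3), dtype=np.uint8)  # illustrative shape

chw = to_channel_first(raw)
restored = to_channel_last(chw)

# Hypothetical state keys: take the latest frame of each and concatenate in dict order.
state = {
    "arm": np.ones((2, 1, 3), dtype=np.float32),
    "gripper": np.zeros((2, 1, 1), dtype=np.float32),
}
flat_state = np.concatenate([arr[:, -1, :] for arr in state.values()], axis=-1)
```

Because the flat state vector depends on iteration order, the artifact has to record the key order explicitly, which is the role of `raw_state_keys`.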
@@ -139,6 +200,36 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
     return nested.get("inputs", nested)


+def _assert_collated_parity(
+    embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
+) -> None:
+    """Compare one collated tensor produced by LeRobot against the original's."""
+    assert isinstance(lerobot_value, torch.Tensor), (
+        f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
+        f"{type(lerobot_value).__name__}, expected a tensor."
+    )
+    lerobot_t = lerobot_value.detach().cpu()
+    original_t = original_value.detach().cpu()
+    assert lerobot_t.shape == original_t.shape, (
+        f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs "
+        f"original={tuple(original_t.shape)}."
+    )
+    if exact:
+        mismatched = int((lerobot_t.long() != original_t.long()).sum())
+        assert mismatched == 0, (
+            f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
+            f"{mismatched}/{original_t.numel()} elements mismatch."
+        )
+    else:
+        lerobot_f, original_f = lerobot_t.float(), original_t.float()
+        max_diff = (lerobot_f - original_f).abs().max().item()
+        print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}")
+        assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), (
+            f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
+            f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
+        )
+
+
 @pytest.fixture(scope="module")
 def lerobot_model():
     """Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
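`_assert_collated_parity` applies two regimes: token ids and other integer tensors must match element for element, while float tensors only need to agree within `atol + rtol * |expected|`, which is the criterion `torch.allclose` documents. A dependency-free sketch of the same split, using plain lists instead of tensors, is shown below. The `compare` helper and the sample values are illustrative and not part of the test suite.

```python
def compare(actual, expected, *, exact, atol=1e-3, rtol=1e-3):
    """Return the number of mismatching elements between two equal-length sequences.

    exact=True  : elements must be equal (ids, masks, grid shapes).
    exact=False : |a - e| <= atol + rtol * |e|, the torch.allclose criterion.
    """
    if len(actual) != len(expected):
        raise ValueError(f"length mismatch: {len(actual)} vs {len(expected)}")
    if exact:
        return sum(a != e for a, e in zip(actual, expected))
    return sum(abs(a - e) > atol + rtol * abs(e) for a, e in zip(actual, expected))


ids_ok = compare([101, 7, 9], [101, 7, 9], exact=True)
ids_bad = compare([101, 7, 9], [101, 8, 9], exact=True)
floats_ok = compare([0.5004, -1.0], [0.5, -1.0], exact=False)
floats_bad = compare([0.51, -1.0], [0.5, -1.0], exact=False)
```

Keeping ids on the exact path matters because a single differing token id is a real tokenizer or template divergence, whereas small float differences in `pixel_values` can come from resize or normalization arithmetic alone.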
@@ -165,8 +256,7 @@ def lerobot_model():

 _ARTIFACTS = _discover_artifacts()

-
-@pytest.mark.skipif(
+_requires_artifacts = pytest.mark.skipif(
     not _ARTIFACTS,
     reason=(
         "No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
@@ -174,24 +264,30 @@ _ARTIFACTS = _discover_artifacts()
         "--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
     ),
 )
+
+
+@_requires_artifacts
 @pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
 def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
     """Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
-    original_action, flat_inputs = _load_artifact(artifact)
+    original_action, flat_inputs, seed = _load_artifact(artifact)
     model_inputs = _unflatten(flat_inputs)

     # Align the flow-matching RNG exactly as the producer did (seed right before sampling).
-    torch.manual_seed(SEED)
+    torch.manual_seed(seed)
     if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(SEED)
+        torch.cuda.manual_seed_all(seed)
     with torch.inference_mode():
         out = lerobot_model.get_action(model_inputs)
     lerobot_action = out["action_pred"].float().cpu()

-    t = min(original_action.shape[1], lerobot_action.shape[1])
-    d = min(original_action.shape[2], lerobot_action.shape[2])
-    original_action = original_action[:, :t, :d]
-    lerobot_action = lerobot_action[:, :t, :d]
+    assert lerobot_action.shape == original_action.shape, (
+        f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
+        f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
+        "The same checkpoint and inputs must produce identical shapes; this indicates an "
+        "action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
+        "utils/dump_original_n1_7.py)."
+    )

     diff = torch.abs(lerobot_action - original_action)
     max_diff = diff.max().item()
@@ -205,3 +301,56 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
         f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
         f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
     )
+
+
+@_requires_artifacts
+@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
+def test_groot_preprocessor_parity(embodiment_tag, artifact):
+    """LeRobot's real preprocessor vs the original's collated tensors, from identical raw obs.
+
+    Runs LeRobot's full preprocessor pipeline -- including the real Qwen3-VL chat
+    template, tokenizer and image packing plus the checkpoint-driven state
+    normalization (no mocks) -- on the raw observations recorded in the artifact, and
+    compares every collated model input against the ones the original ``gr00t``
+    processor produced from the same raw observations.
+    """
+    raw = _load_raw_observation(artifact)
+    if raw is None:
+        pytest.skip(
+            f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
+            "not record raw observations; regenerate it with the current dump script to run the "
+            "preprocessor parity case."
+        )
+    _, flat_inputs, _ = _load_artifact(artifact)
+    original_inputs = _unflatten(flat_inputs)
+
+    ckpt = _resolve_checkpoint()
+    from lerobot.policies.groot.configuration_groot import GrootConfig
+    from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
+
+    # CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
+    config = GrootConfig(base_model_path=ckpt, embodiment_tag=embodiment_tag, device="cpu")
+    preprocessor, _ = make_groot_pre_post_processors(config)
+
+    processed = preprocessor(_raw_observation_to_lerobot_batch(raw))
+
+    compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS)
+    missing_original = [k for k in compared_keys if k not in original_inputs]
+    missing_lerobot = [k for k in compared_keys if k not in processed]
+    assert not missing_original, (
+        f"[{embodiment_tag}] artifact collated inputs miss {missing_original} "
+        f"(available: {sorted(original_inputs)}); regenerate the artifact with the current dump script."
+    )
+    assert not missing_lerobot, (
+        f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys "
+        f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})."
+    )
+
+    for name in compared_keys:
+        _assert_collated_parity(
+            embodiment_tag,
+            name,
+            processed[name],
+            original_inputs[name],
+            exact=name in _COLLATED_EXACT_KEYS,
+        )
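Model parity depends on both sides drawing the same noise, which is why the model-parity test reseeds immediately before sampling and reads the seed from the artifact when one is recorded. The sketch below shows that pattern with the standard library's `random` module standing in for the torch RNG. `resolve_seed`, `sample_noise`, and the dict-shaped artifact are hypothetical simplifications, not code from the test module.

```python
import random
import warnings

FALLBACK_SEED = 42  # plays the role of the SEED fallback in the test module


def resolve_seed(artifact: dict) -> int:
    """Prefer the seed recorded by the producer; warn and fall back otherwise."""
    if "seed" in artifact:
        return int(artifact["seed"])
    warnings.warn("artifact records no seed; using fallback", stacklevel=2)
    return FALLBACK_SEED


def sample_noise(seed: int, n: int = 4) -> list[float]:
    """Seed right before sampling so earlier RNG use cannot shift the draws."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


producer = sample_noise(resolve_seed({"seed": 7}))
random.random()  # unrelated RNG use between the two runs
consumer = sample_noise(resolve_seed({"seed": 7}))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_seed = resolve_seed({})
```

Warning instead of failing keeps old artifacts usable, while the message still tells the reader what to regenerate if the comparison later fails.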
@@ -9,6 +9,9 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
 imported in the same Python process. To keep the parity comparison FAIR, we run the
 original model in its native env here and serialize, PER EMBODIMENT TAG:

+  * the RAW observation fed to the original processor (per-camera uint8 frames,
+    per-key state vectors, the language instruction), so the LeRobot side can also
+    run its OWN preprocessor on identical raw inputs and compare collated tensors,
   * the exact pre-processed/collated model inputs (so the LeRobot side consumes the
     byte-identical tensors -- same image preprocessing, tokenization, normalization),
   * the random seed used right before the flow-matching sampler,
@@ -21,8 +24,10 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
 from the SAME checkpoint and confirm the LeRobot integration is not overfit to
 ``libero_sim``.

-The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
-inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
+The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
+twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
+(model parity), and the raw observation is replayed through LeRobot's own
+preprocessor pipeline and compared against the collated inputs (preprocessor parity).

 Usage:
     .venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
@@ -62,10 +67,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
     # One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
     # Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
     # present-but-empty so the processor's state transform finds every expected key.
-    state = {
-        k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
-        for k, dim in state_spec
-    }
+    state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
     language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
     return {"video": video, "state": state, "language": language}

@@ -77,6 +79,25 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
     lang_key = modality_cfg["language"].modality_keys[0]
     observation = make_observation(args.seed, video_keys, lang_key, state_spec)

+    # Snapshot the RAW observation exactly as fed to the original processor below. The
+    # consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
+    # and compares the resulting collated tensors against the "in::" ones saved further
+    # down. raw_state_keys records the checkpoint modality-key order, which is the
+    # concatenation order of the flat LeRobot ``observation.state`` vector.
+    spec_keys = [key for key, _ in state_spec]
+    state_modality = modality_cfg.get("state")
+    state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
+    state_keys += [key for key in spec_keys if key not in state_keys]
+    raw_language = [
+        str(item[0]) if isinstance(item, (list, tuple)) else str(item)
+        for item in observation["language"][lang_key]
+    ]
+    raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
+    raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
+    raw_flat["raw::language"] = np.array(raw_language, dtype=object)
+    raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
+    raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
+
     # Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
     policy.embodiment_tag = type(policy.embodiment_tag)(tag)
     policy.modality_configs = {
@@ -136,6 +157,7 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
         embodiment_tag=np.array(tag),
         meta_keys=np.array(list(meta.keys()), dtype=object),
         meta_dtypes=np.array(list(meta.values()), dtype=object),
+        **raw_flat,
         **flat,
     )
     print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
@@ -181,7 +203,12 @@ def main():
         state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
         try:
             dump_one_tag(
-                policy, fair_model, tag, all_modality[tag], state_spec, args,
+                policy,
+                fair_model,
+                tag,
+                all_modality[tag],
+                state_spec,
+                args,
                 out_dir / f"original_n1_7_{tag}.npz",
             )
             done.append(tag)
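The `raw_state_keys` ordering built in `dump_one_tag` matters because the LeRobot side concatenates the per-key state vectors in exactly that order. A small sketch of the ordering rule follows: modality-config keys that have statistics come first, in their declared order, followed by any keys that appear only in the statistics. The `ordered_state_keys` helper and the key names are made up for illustration.

```python
from typing import Optional


def ordered_state_keys(modality_keys: Optional[list], spec_keys: list) -> list:
    """Modality-config keys that have statistics first, then remaining spec keys."""
    ordered = [key for key in (modality_keys or []) if key in spec_keys]
    ordered += [key for key in spec_keys if key not in ordered]
    return ordered


# Illustrative key names only.
spec = ["gripper", "arm", "eef"]
with_modality = ordered_state_keys(["arm", "gripper", "unused"], spec)
without_modality = ordered_state_keys(None, spec)
```

Here `with_modality` follows the modality config and appends `eef`, while `without_modality` falls back to the statistics order unchanged.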