Compare commits

..

8 Commits

Author SHA1 Message Date
Steven Palma 226a4c5a8c fix test relative expr 2026-06-15 21:07:01 +02:00
Steven Palma 05a9ca274b relative experiment 2026-06-15 16:38:36 +02:00
Steven Palma 13ed657056 fix(groot): GPU/tensor N1.7 image preprocessing + resize to trained resolution
GR00T training was dataloader-bound (0->100->0 GPU-utilization sawtooth).
GrootN17VLMEncodeStep ran the Qwen3-VL image processor per frame on PIL images
on the single CPU main-loop thread, and that cost is timed inside dataloading_s
(preprocessor(batch) runs in the main process, not the dataloader workers), so
adding workers cannot hide it.

- Feed the torchvision-backed Qwen3-VL processor (C,H,W) uint8 tensors instead
  of a per-frame Image.fromarray PIL roundtrip, and run resize/normalize/patchify
  on config.device (GPU) when available. Bit-identical on CPU when no resize is
  configured; with a resize only the PIL->torchvision bicubic backend differs
  (<2/255 per pixel). The use_albumentations path stays PIL/cv2; reload on a box
  without the saved device falls back to CPU.

- Default image_target_size/crop to the N1.7 backbone's training geometry
  (256x256 / 230x230) when a checkpoint ships no image sizing (checkpoint_assets
  is None, e.g. finetuning nvidia/GR00T-N1.7-3B via repo-id with a new
  embodiment). Previously image_target_size=None disabled the resize, so
  full-resolution frames were patchified into ~4.7x more vision tokens than the
  model was trained on -- inflating dataloading_s (patchify) and update_s (VLM
  sequence) and skewing the input distribution. Checkpoints that pin their own
  sizing are honored; the default constants are shared with GR00T_N1_7_DEFAULTS.

Net: preprocessing leaves the CPU critical path and the VLM sees the resolution
it was trained on -- faster training/inference and a correct train/serve
distribution. Affects inference too (shared preprocessor); existing checkpoints
still load (backward compatible) but must be retrained to gain the benefits.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 11:11:34 +02:00
Steven Palma 559cba212d Merge commit 'refs/groot/docs'; commit 'refs/groot/backbone'; commit 'refs/groot/core' into fix/groot_training_experiment 2026-06-13 19:59:57 +02:00
Steven Palma 378897800a fix(groot): skip normalization overrides for training 2026-06-13 19:51:29 +02:00
Steven Palma fcb371eddd fix(groot): N1.7 config defaults, N1.5 rejection, and processor/model runtime fixes
Covers the GR00T N1.7 source trio (configuration, processor, model wrapper).

Config:
- GrootConfig defaults are the N1.7 values; explicitly passed legacy N1.5-era
  values (chunk_size=50, max_state_dim=64, ...) are remapped with a warning
  instead of silently.
- action_decode_transform gains an 'auto' sentinel so an explicit 'none'
  opt-out wins over the libero_sim default and survives save/load round-trips.
- action_delta_indices is cached on the inputs that determine it.
- Legacy N1.5 checkpoints/configs (tokenizer_assets_repo, model_type/
  architectures/eagle backbone markers) are rejected with a single clear
  error pointing to lerobot==0.5.1.

Processor:
- GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by sync
  select_action (relative eef/non-eef decode in eval/record flows).
- Postprocessor falls back to dataset stats when a raw checkpoint lacks the
  configured embodiment tag; raw-state cache is per-instance, not
  process-global; caller overrides (device, rename_map) are honored on the
  raw-checkpoint branch.
- Camera/modality-key mismatches warn (including the zero-match fallback);
  deprecated Qwen2VLImageProcessorFast replaced with Qwen2VLImageProcessor;
  removed N1.5 processor steps are stubbed to raise the removal guidance and
  the action-unpack step is re-registered as _v2.

Model:
- Flash-attention probe is diagnostic-only; forward raises on a missing loss;
  print() replaced with logging; N1.5 base-path mismatch includes the
  removal guidance.
2026-06-13 18:30:21 +02:00
Steven Palma 895eaf0d7c fix(groot): N1.7 backbone loading and DiT parameter-count logging
- select_layer default tracks the N1.7-3B checkpoint value (16); real
  checkpoint loads still override it from config.json.
- get_backbone_cls recognizes Cosmos-Reason2 / Qwen3-VL backbones by name and
  warns (instead of silently assuming) when an unrecognized backbone is loaded
  only on the strength of backbone_model_type='qwen'.
- 'revision' pins the GR00T checkpoint repo only and is no longer forwarded
  into the unrelated backbone repo load; pin the backbone via
  transformers_loading_kwargs instead.
- DiT / SelfAttentionTransformer parameter counts go through logging.debug
  instead of print().
2026-06-12 23:55:33 +02:00
Steven Palma edda8552ec docs(groot): document the N1.5 removal and the N1.7 parity test
- groot.mdx: breaking-change warning and migration path (pin lerobot==0.5.1 to
  keep N1.5, or move to N1.7); the dead `huggingface-cli download` is replaced
  with `hf download`.
- policy_groot_README.md: N1.5 removal note, updated paper / model-card links,
  and the two-comparison (model parity + preprocessor parity) description of
  the original-vs-LeRobot test, including the raw-observation artifacts and
  recorded seed.
2026-06-12 23:40:36 +02:00
9 changed files with 1094 additions and 1154 deletions
+4 -1
View File
@@ -4,6 +4,9 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
LeRobot integrates GR00T N1.7 through the `groot` policy type.
> [!WARNING]
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
## Model Overview
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
@@ -133,7 +136,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
hf download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
+52 -23
View File
@@ -1,6 +1,13 @@
## Research Paper
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
> Current releases support GR00T N1.7 only.
## Repository
@@ -31,12 +38,22 @@ Hugging Face Models:
## Original-vs-LeRobot parity test
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
produces the **same raw model output** (`get_action(...)["action_pred"]`, the
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
byte-identical pre-processed inputs and the same flow-matching seed. It is
parametrized over every embodiment tag present in the checkpoint.
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
over every embodiment tag present in the checkpoint:
1. **Model parity** — given byte-identical pre-processed inputs and the same
flow-matching seed (recorded in each artifact), both implementations must produce
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
flow-matching prediction). Output shapes must match exactly; any action-horizon
or action-dim mismatch fails the test.
2. **Preprocessor parity** — given the identical raw observations (per-camera
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
state normalization, no mocks) must produce the **same collated model inputs**
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
`embodiment_id`) as the original package's processor.
### Why two environments
@@ -48,25 +65,37 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
So the test uses a **producer / consumer** split across two venvs:
1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
gr00t venv. For each embodiment it builds dummy inputs generically from the
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
the processor modality configs), runs the original model, and saves the exact
collated inputs + raw `action_pred` to one `.npz` per tag.
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
`.npz`, replays the byte-identical inputs through the LeRobot model with the same
seed, and asserts the outputs match.
the processor modality configs), runs the original model, and saves to one `.npz`
per tag: the raw observations (`raw::` keys), the exact collated inputs
(`in::` keys), the seed, and the raw `action_pred`.
2. **Consumer** the pytest above, run in the _LeRobot_ venv. It discovers every
`.npz`; the model-parity case replays the byte-identical collated inputs through
the LeRobot model with the recorded seed and asserts the outputs match, and the
preprocessor-parity case replays the raw observations through LeRobot's full
preprocessor pipeline and asserts the collated tensors match.
> Artifacts generated by older versions of the dump script contain no `raw::`
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
> Re-run the producer to refresh them.
### Fairness controls
- **Same pre-processed inputs** — the original processor's `input_ids`,
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
model comparison isolates the model. LeRobot's own tokenization / image packing is
covered separately by the preprocessor-parity case, which compares its output
against those same collated tensors from identical raw observations.
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
kernel/rounding noise, not an implementation difference.)
- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
- **Same flow-matching seed** — fixed right before sampling on both sides; the
producer records it in each artifact (`--seed`, default 42) and the consumer
replays the recorded value.
### How to run
@@ -90,15 +119,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
```
The `.npz` artifacts are local-only (gitignored, ~69 MB each) and are regenerated by
the producer; they are never committed. The test **skips** (does not fail) on CI or
The `.npz` artifacts are local-only (gitignored, ~610 MB each) and are regenerated by
the producer; they are never committed. The tests **skip** (do not fail) on CI or
when the checkpoint / artifacts are absent.
#### Env knobs (all optional)
| Var | Default | Purpose |
|---|---|---|
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
| Var | Default | Purpose |
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
import torch
@@ -42,6 +43,9 @@ else:
Timesteps = None
logger = logging.getLogger(__name__)
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
require_package("diffusers", extra="groot")
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
print(
"Total number of DiT parameters: ",
logger.debug(
"Total number of DiT parameters: %d",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
for _ in range(self.config.num_layers)
]
)
print(
"Total number of SelfAttentionTransformer parameters: ",
logger.debug(
"Total number of SelfAttentionTransformer parameters: %d",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
+140 -37
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -23,15 +24,37 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
logger = logging.getLogger(__name__)
GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5"
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
GROOT_N1_5_REMOVAL_GUIDANCE = (
"GR00T N1.5 support was removed from LeRobot. "
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
# falls back to these when a checkpoint ships no image sizing in its processor_config
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
# are resized to the expected resolution instead of being patchified at full camera
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7,
@@ -41,7 +64,12 @@ _GROOT_MODEL_VERSION_ALIASES = {
"1.7": GROOT_N1_7,
}
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
"none": None,
"": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
@@ -52,9 +80,10 @@ def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = GROOT_N1_7
raise ValueError(
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
)
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
return normalized
@@ -286,6 +315,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version")
if isinstance(model_version, str):
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
return GROOT_N1_5
try:
return normalize_groot_model_version(model_version)
except ValueError:
@@ -298,8 +329,14 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
backbone_cfg = config.get("backbone_cfg")
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
return GROOT_N1_5
return None
@@ -310,29 +347,30 @@ class GrootConfig(PreTrainedConfig):
# Basic policy settings
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
chunk_size: int = 40
n_action_steps: int = 40
# Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 64
max_state_dim: int = 132
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 32
max_action_dim: int = 132
# Normalization (start with identity, adjust as needed)
# GR00T normalizes state/action internally in its processor steps (min/max with
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
# handles image normalization. The policy therefore does NOT use LeRobot's
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
# Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Groot-specific model parameters (from groot_finetune_script.py)
# Groot-specific model parameters
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
@@ -344,11 +382,47 @@ class GrootConfig(PreTrainedConfig):
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
action_decode_transform: str | None = None
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
# Inference-only override for the number of flow-matching denoising steps used to decode an
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
# inference speed for action quality; applied at base-model load via _create_groot_model.
num_inference_timesteps: int | None = None
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
execution_horizon: int | None = None
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
# base-model build (so it applies during lerobot-train, which uses __init__ not
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
warm_start_embodiment_slot: int | str | None = None
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
# When True, GR00T converts absolute->relative inside its own pack step (training) and
# reconstructs absolute inside its own flat decode step (inference), using a cached
# reference state. The dataset stays absolute; compute relative ACTION stats with
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
# "['gripper']"` (this only rewrites stats, not actions).
use_relative_actions: bool = False
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
# Case-insensitive token match against action_feature_names.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
# identified and kept absolute. When None, the gripper cannot be identified.
action_feature_names: list[str] | None = None
# Fine-tuning control arguments
# Whether to fine-tune the llm backbone
@@ -384,17 +458,16 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05
use_bf16: bool = True
# Dataset parameters
# Video backend to use for training ('decord' or 'torchvision_av')
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t"
save_steps: int = 1000
@@ -405,6 +478,12 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
)
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
@@ -416,26 +495,48 @@ class GrootConfig(PreTrainedConfig):
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
# 'none' (normalized to None above) keeps the transform disabled.
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
self.action_decode_transform = (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
)
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
# warning. The dataclass defaults are already the N1.7 values, so a plain
# GrootConfig() never triggers this.
legacy_default_remaps = (
("max_state_dim", 64, 132),
("max_action_dim", 32, 132),
("chunk_size", 50, 40),
("n_action_steps", 50, 40),
("image_size", (224, 224), (256, 256)),
)
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
current_value = getattr(self, field_name)
if isinstance(legacy_value, tuple):
current_value = tuple(current_value)
if current_value == legacy_value:
logger.warning(
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
"if this is not what you want.",
field_name,
legacy_value,
n1_7_value,
)
setattr(self, field_name, n1_7_value)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
raise ValueError(
message = (
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
super().__post_init__()
@@ -512,7 +613,9 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
return list(range(min(self.chunk_size, model_action_horizon)))
@property
+13 -9
View File
@@ -32,6 +32,7 @@ from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
@@ -71,13 +72,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 12,
"select_layer": 16,
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": (230, 230),
"image_target_size": (256, 256),
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
"shortest_image_edge": None,
"crop_fraction": None,
"random_rotation_angle": None,
@@ -819,11 +820,14 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
def get_backbone_cls(config: GR00TN17Config):
if (
config.backbone_model_type == "qwen"
or "nvidia/Cosmos-Reason2" in config.model_name
or "Qwen/Qwen3-VL" in config.model_name
):
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
return Qwen3Backbone
if config.backbone_model_type == "qwen":
logger.warning(
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
"backbone because backbone_model_type='qwen'.",
config.model_name,
)
return Qwen3Backbone
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
@@ -909,7 +913,7 @@ class GR00TN17(PreTrainedModel):
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
for key in ("revision", "cache_dir", "local_files_only", "token"):
for key in ("cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
+166 -50
View File
@@ -18,15 +18,12 @@
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
possible without porting their code. Dataset loading and training
orchestration are handled by LeRobot's standard training stack.
"""
import builtins
import logging
import os
from collections import deque
from pathlib import Path
@@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_groot import (
GROOT_N1_5,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
@@ -50,9 +49,103 @@ from .configuration_groot import (
normalize_groot_model_version,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy")
def _resolve_embodiment_id(value: int | str) -> int:
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
Raises ValueError listing the known keys if the name is unknown.
"""
from .processor_groot import N1_7_EMBODIMENT_MAPPING
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
if isinstance(value, int):
return value
if value in N1_7_EMBODIMENT_MAPPING:
return N1_7_EMBODIMENT_MAPPING[value]
raise ValueError(
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
)
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
"""Copy category-specific action-head weights from one embodiment slot to another.
Used at base-model load (training only) to warm-start a cold target embodiment slot
(e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
parameters across every CategorySpecificLinear in the action head's state encoder,
action encoder, and action decoder. No-ops (with a logged warning) if the ids are out
of range or identical.
"""
if source_id == target_id:
logger.warning(
"GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
"skipping (nothing to copy).",
source_id,
)
return
action_head = getattr(model, "action_head", None)
if action_head is None:
logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
return
# Each entry is (submodule, [CategorySpecificLinear attribute names]).
linear_groups = [
(getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
(getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
(getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
]
copied: list[str] = []
with torch.no_grad():
for submodule, attr_names in linear_groups:
if submodule is None:
continue
submodule_name = type(submodule).__name__
for attr_name in attr_names:
lin = getattr(submodule, attr_name, None)
if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
continue
num_categories = lin.W.shape[0]
if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
logger.warning(
"GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
"for %s.%s (num_categories=%d); skipping this layer.",
source_id,
target_id,
submodule_name,
attr_name,
num_categories,
)
continue
lin.W.data[target_id] = lin.W.data[source_id].clone()
lin.b.data[target_id] = lin.b.data[source_id].clone()
copied.append(f"{submodule_name}.{attr_name}")
if copied:
logger.info(
"GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
"to slot %d for: %s.",
source_id,
target_id,
", ".join(copied),
)
else:
logger.warning(
"GR00T warm_start_embodiment_slot: no action-head weights were copied "
"(source_id=%d, target_id=%d).",
source_id,
target_id,
)
class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration."""
@@ -92,8 +185,24 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True},
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
# Inference-only override for the number of flow-matching denoising steps. The action
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
# t schedule adapt automatically.
if self.config.num_inference_timesteps is not None:
n = int(self.config.num_inference_timesteps)
model.config.num_inference_timesteps = n
model.action_head.num_inference_timesteps = n
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
if self.config.warm_start_embodiment_slot is not None:
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
_warm_start_embodiment_slot(model, source_id, target_id)
return model
@@ -148,9 +257,10 @@ class GrootPolicy(PreTrainedPolicy):
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
print(
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
logger.info(
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
requested_version,
pretrained_name_or_path,
)
model_id = str(pretrained_name_or_path)
@@ -181,7 +291,7 @@ class GrootPolicy(PreTrainedPolicy):
if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
@@ -197,7 +307,7 @@ class GrootPolicy(PreTrainedPolicy):
)
# This is a base GR00T model - load it fresh
print("Detected base GR00T model, loading from HuggingFace...")
logger.info("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
@@ -229,10 +339,13 @@ class GrootPolicy(PreTrainedPolicy):
config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version:
raise ValueError(
message = (
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -258,7 +371,11 @@ class GrootPolicy(PreTrainedPolicy):
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
return min(horizons)
# An explicit config override caps the open-loop horizon (inference cadence), overriding
# the value inferred from the checkpoint/embodiment.
if self.config.execution_horizon is not None:
horizons.append(max(1, int(self.config.execution_horizon)))
return max(1, min(horizons))
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
@@ -297,9 +414,7 @@ class GrootPolicy(PreTrainedPolicy):
allowed_base.add("action_mask")
return {
k: v
for k, v in batch.items()
if k in allowed_base and not (k.startswith("next.") or k == "info")
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
@@ -320,9 +435,7 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.ndim == 2:
prev_actions = prev_actions.unsqueeze(0)
elif prev_actions.ndim != 3:
raise ValueError(
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
)
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
state = inputs.get("state")
if state is None:
@@ -331,9 +444,7 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.shape[0] == 1 and batch_size > 1:
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
elif prev_actions.shape[0] != batch_size:
raise ValueError(
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
)
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
# The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
@@ -346,7 +457,9 @@ class GrootPolicy(PreTrainedPolicy):
else:
return inputs, None
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
model_action_horizon = int(
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
)
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon:
prev_actions = prev_actions[:, -model_action_horizon:, :]
@@ -409,6 +522,11 @@ class GrootPolicy(PreTrainedPolicy):
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss")
if loss is None:
raise RuntimeError(
"GR00T model.forward did not return a 'loss'. Training batches must include "
"'action' and 'action_mask'; check the preprocessor output."
)
loss_dict = {"loss": loss.item()}
@@ -425,6 +543,16 @@ class GrootPolicy(PreTrainedPolicy):
"""
self.eval()
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
from .processor_groot import _GROOT_REF_HOLDER_KEY
holder = batch.get(_GROOT_REF_HOLDER_KEY)
if holder is not None:
holder.freeze()
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
@@ -471,33 +599,21 @@ class GrootPolicy(PreTrainedPolicy):
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Handle Flash Attention compatibility issues by setting environment variables.
"""Log Flash Attention availability (diagnostic only).
This addresses the common 'undefined symbol' error that occurs when Flash Attention
is compiled against a different PyTorch version than what's currently installed.
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
change behaviour or mutate global state.
"""
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try:
import flash_attn
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError as e:
print(f"[GROOT] Flash Attention not available: {e}")
print("[GROOT] Will use fallback attention mechanism")
except Exception as e:
if "undefined symbol" in str(e):
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
except ImportError:
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
except Exception as e: # noqa: BLE001
logger.warning(
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
e,
)
File diff suppressed because it is too large Load Diff
@@ -207,11 +207,6 @@ def test_lerobot_groot_forward_pass():
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
assert isinstance(lerobot_loss, torch.Tensor)
assert torch.isfinite(lerobot_loss).all()
assert "loss" in lerobot_metrics
assert np.isfinite(lerobot_metrics["loss"])
print("\nForward pass successful.")
print(f" - Loss: {lerobot_loss.item():.6f}")
print(f" - Metrics: {lerobot_metrics}")
File diff suppressed because it is too large Load Diff