Compare commits

8 Commits

Author SHA1 Message Date
Steven Palma 226a4c5a8c fix test relative expr 2026-06-15 21:07:01 +02:00
Steven Palma 05a9ca274b relative experiment 2026-06-15 16:38:36 +02:00
Steven Palma 13ed657056 fix(groot): GPU/tensor N1.7 image preprocessing + resize to trained resolution
GR00T training was dataloader-bound (0->100->0 GPU-utilization sawtooth).
GrootN17VLMEncodeStep ran the Qwen3-VL image processor per frame on PIL images
on the single CPU main-loop thread, and that cost is timed inside dataloading_s
(preprocessor(batch) runs in the main process, not the dataloader workers), so
adding workers cannot hide it.

- Feed the torchvision-backed Qwen3-VL processor (C,H,W) uint8 tensors instead
  of a per-frame Image.fromarray PIL roundtrip, and run resize/normalize/patchify
  on config.device (GPU) when available. Bit-identical on CPU when no resize is
  configured; with a resize only the PIL->torchvision bicubic backend differs
  (<2/255 per pixel). The use_albumentations path stays PIL/cv2; reload on a box
  without the saved device falls back to CPU.

- Default image_target_size/crop to the N1.7 backbone's training geometry
  (256x256 / 230x230) when a checkpoint ships no image sizing (checkpoint_assets
  is None, e.g. finetuning nvidia/GR00T-N1.7-3B via repo-id with a new
  embodiment). Previously image_target_size=None disabled the resize, so
  full-resolution frames were patchified into ~4.7x more vision tokens than the
  model was trained on -- inflating dataloading_s (patchify) and update_s (VLM
  sequence) and skewing the input distribution. Checkpoints that pin their own
  sizing are honored; the default constants are shared with GR00T_N1_7_DEFAULTS.

Net: preprocessing leaves the CPU critical path and the VLM sees the resolution
it was trained on -- faster training/inference and a correct train/serve
distribution. Affects inference too (shared preprocessor); existing checkpoints
still load (backward compatible) but must be retrained to gain the benefits.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 11:11:34 +02:00
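The sizing fallback described in commit 13ed657056 can be sketched as below. This is a minimal illustration under stated assumptions, not the processor's actual code: the helper name `resolve_image_sizing` and the dict layout of `checkpoint_sizing` are hypothetical. Only the 256x256 / 230x230 defaults, and the rule that a checkpoint's own sizing wins, come from the commit message.

```python
from typing import Optional

# Defaults quoted in the commit message (N1.7 backbone training geometry).
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)


def resolve_image_sizing(checkpoint_sizing: Optional[dict]) -> dict:
    """Hypothetical helper: choose resize/crop sizes for preprocessing.

    `checkpoint_sizing` stands in for whatever sizing a checkpoint ships
    (None when it ships none, e.g. the raw base model with a new embodiment).
    """
    sizing = dict(checkpoint_sizing or {})
    target = sizing.get("image_target_size") or N1_7_DEFAULT_IMAGE_TARGET_SIZE
    crop = sizing.get("image_crop_size") or N1_7_DEFAULT_IMAGE_CROP_SIZE
    # Checkpoint-pinned values are honored; missing ones fall back to defaults.
    return {"image_target_size": tuple(target), "image_crop_size": tuple(crop)}
```

With `None` the sketch returns the N1.7 defaults, so frames are resized instead of being patchified at full camera resolution; a checkpoint that pins its own sizes keeps them.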
Steven Palma 559cba212d Merge commit 'refs/groot/docs'; commit 'refs/groot/backbone'; commit 'refs/groot/core' into fix/groot_training_experiment 2026-06-13 19:59:57 +02:00
Steven Palma 378897800a fix(groot): skip normalization overrides for training 2026-06-13 19:51:29 +02:00
Steven Palma fcb371eddd fix(groot): N1.7 config defaults, N1.5 rejection, and processor/model runtime fixes
Covers the GR00T N1.7 source trio (configuration, processor, model wrapper).

Config:
- GrootConfig defaults are the N1.7 values; explicitly passed legacy N1.5-era
  values (chunk_size=50, max_state_dim=64, ...) are remapped with a warning
  instead of being remapped silently.
- action_decode_transform gains an 'auto' sentinel so an explicit 'none'
  opt-out wins over the libero_sim default and survives save/load round-trips.
- action_delta_indices is cached on the inputs that determine it.
- Legacy N1.5 checkpoints/configs (tokenizer_assets_repo, model_type/
  architectures/eagle backbone markers) are rejected with a single clear
  error pointing to lerobot==0.5.1.

Processor:
- GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by sync
  select_action (relative eef/non-eef decode in eval/record flows).
- Postprocessor falls back to dataset stats when a raw checkpoint lacks the
  configured embodiment tag; raw-state cache is per-instance, not
  process-global; caller overrides (device, rename_map) are honored on the
  raw-checkpoint branch.
- Camera/modality-key mismatches warn (including the zero-match fallback);
  deprecated Qwen2VLImageProcessorFast replaced with Qwen2VLImageProcessor;
  removed N1.5 processor steps are stubbed to raise the removal guidance and
  the action-unpack step is re-registered as _v2.

Model:
- Flash-attention probe is diagnostic-only; forward raises on a missing loss;
  print() replaced with logging; N1.5 base-path mismatch includes the
  removal guidance.
2026-06-13 18:30:21 +02:00
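The `'auto'` sentinel from commit fcb371eddd can be illustrated with a small standalone sketch. The function name `resolve_action_decode_transform` is hypothetical, the alias table contains only the entries visible in the diff, and the error raised for unknown names is an assumption. The real logic is split between `normalize_groot_action_decode_transform` and `GrootConfig.__post_init__`.

```python
from typing import Optional

AUTO = "auto"      # sentinel: the user did not pick a transform
LIBERO = "libero"

# Only the alias entries visible in the diff; the real table may hold more.
_ALIASES = {AUTO: AUTO, "none": None, "": None, LIBERO: LIBERO}


def resolve_action_decode_transform(value: Optional[str], embodiment_tag: str) -> Optional[str]:
    """Hypothetical one-shot version of normalize + resolve."""
    if value is not None:
        key = value.lower()
        if key not in _ALIASES:
            # Assumption: unknown names are rejected.
            raise ValueError(f"Unknown action_decode_transform '{value}'")
        value = _ALIASES[key]
    if value == AUTO:
        # Only the sentinel picks up the embodiment default.
        return LIBERO if embodiment_tag == "libero_sim" else None
    # An explicit 'none' arrives here as None and stays disabled.
    return value
```

Because an explicit opt-out resolves to `None` rather than back to `'auto'`, saving and reloading the resolved config does not re-enable the `libero_sim` default.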
Steven Palma 895eaf0d7c fix(groot): N1.7 backbone loading and DiT parameter-count logging
- select_layer default tracks the N1.7-3B checkpoint value (16); real
  checkpoint loads still override it from config.json.
- get_backbone_cls recognizes Cosmos-Reason2 / Qwen3-VL backbones by name and
  warns (instead of silently assuming) when an unrecognized backbone is loaded
  only on the strength of backbone_model_type='qwen'.
- 'revision' pins the GR00T checkpoint repo only and is no longer forwarded
  into the unrelated backbone repo load; pin the backbone via
  transformers_loading_kwargs instead.
- DiT / SelfAttentionTransformer parameter counts go through logging.debug
  instead of print().
2026-06-12 23:55:33 +02:00
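The effect of routing parameter counts through `logging.debug` instead of `print()` (commit 895eaf0d7c) can be seen in a small stdlib-only sketch. The logger name, the `log_parameter_count` helper, and the list of integers standing in for `p.numel()` values are all hypothetical.

```python
import logging

logger = logging.getLogger("groot_sketch")  # hypothetical logger name


def log_parameter_count(name: str, parameter_sizes: list) -> int:
    """Sum stand-in parameter sizes and report them at DEBUG level."""
    total = sum(parameter_sizes)
    # Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
    logger.debug("Total number of %s parameters: %d", name, total)
    return total


class ListHandler(logging.Handler):
    """Collects formatted messages so the behavior can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


handler = ListHandler()
logger.addHandler(handler)

logger.setLevel(logging.INFO)       # typical default: DEBUG output is dropped
log_parameter_count("DiT", [10, 20])
quiet = list(handler.messages)

logger.setLevel(logging.DEBUG)      # opt in to see the count
log_parameter_count("DiT", [10, 20])
verbose = list(handler.messages)
```

Unlike `print()`, the message stays silent unless the caller raises the logger's verbosity.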
Steven Palma edda8552ec docs(groot): document the N1.5 removal and the N1.7 parity test
- groot.mdx: breaking-change warning and migration path (pin lerobot==0.5.1 to
  keep N1.5, or move to N1.7); the dead `huggingface-cli download` is replaced
  with `hf download`.
- policy_groot_README.md: N1.5 removal note, updated paper / model-card links,
  and the two-comparison (model parity + preprocessor parity) description of
  the original-vs-LeRobot test, including the raw-observation artifacts and
  recorded seed.
2026-06-12 23:40:36 +02:00
9 changed files with 1094 additions and 1154 deletions
+4 -1
@@ -4,6 +4,9 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
LeRobot integrates GR00T N1.7 through the `groot` policy type.
> [!WARNING]
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
## Model Overview
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
@@ -133,7 +136,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash ```bash
-huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
+hf download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
+52 -23
@@ -1,6 +1,13 @@
## Research Paper
-Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
+GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
> Current releases support GR00T N1.7 only.
## Repository
@@ -31,12 +38,22 @@ Hugging Face Models:
## Original-vs-LeRobot parity test ## Original-vs-LeRobot parity test
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot `tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head) reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
produces the **same raw model output** (`get_action(...)["action_pred"]`, the against NVIDIA's original `gr00t` package with two comparisons, each parametrized
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given over every embodiment tag present in the checkpoint:
byte-identical pre-processed inputs and the same flow-matching seed. It is
parametrized over every embodiment tag present in the checkpoint. 1. **Model parity** — given byte-identical pre-processed inputs and the same
flow-matching seed (recorded in each artifact), both implementations must produce
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
flow-matching prediction). Output shapes must match exactly; any action-horizon
or action-dim mismatch fails the test.
2. **Preprocessor parity** — given the identical raw observations (per-camera
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
state normalization, no mocks) must produce the **same collated model inputs**
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
`embodiment_id`) as the original package's processor.
### Why two environments ### Why two environments
@@ -48,25 +65,37 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
So the test uses a **producer / consumer** split across two venvs: So the test uses a **producer / consumer** split across two venvs:
1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original* 1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
gr00t venv. For each embodiment it builds dummy inputs generically from the gr00t venv. For each embodiment it builds dummy inputs generically from the
checkpoint metadata (state dims from `statistics.json`; camera/language keys from checkpoint metadata (state dims from `statistics.json`; camera/language keys from
the processor modality configs), runs the original model, and saves the exact the processor modality configs), runs the original model, and saves to one `.npz`
collated inputs + raw `action_pred` to one `.npz` per tag. per tag: the raw observations (`raw::` keys), the exact collated inputs
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every (`in::` keys), the seed, and the raw `action_pred`.
`.npz`, replays the byte-identical inputs through the LeRobot model with the same 2. **Consumer** the pytest above, run in the _LeRobot_ venv. It discovers every
seed, and asserts the outputs match. `.npz`; the model-parity case replays the byte-identical collated inputs through
the LeRobot model with the recorded seed and asserts the outputs match, and the
preprocessor-parity case replays the raw observations through LeRobot's full
preprocessor pipeline and asserts the collated tensors match.
> Artifacts generated by older versions of the dump script contain no `raw::`
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
> Re-run the producer to refresh them.
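The artifact layout described above (`raw::` keys for raw observations, `in::` keys for collated inputs) suggests how the consumer can decide whether the preprocessor-parity case is runnable. The following is a sketch only: `split_artifact_keys` and `preprocessor_parity_can_run` are hypothetical helpers, and the example key names after the prefixes are illustrative rather than taken from the dump script.

```python
def split_artifact_keys(keys):
    """Partition artifact keys by the prefixes named in the README."""
    raw = sorted(k for k in keys if k.startswith("raw::"))
    collated = sorted(k for k in keys if k.startswith("in::"))
    return raw, collated


def preprocessor_parity_can_run(keys) -> bool:
    """Older artifacts have no raw:: fields, so that case must skip."""
    raw, _ = split_artifact_keys(keys)
    return bool(raw)


# Illustrative key sets (names after the prefix are assumptions).
new_artifact = ["in::input_ids", "in::state", "raw::state", "seed", "action_pred"]
old_artifact = ["in::input_ids", "in::state", "action_pred"]
```

Here `preprocessor_parity_can_run(old_artifact)` is false, which corresponds to the skip-with-regeneration-hint path, while the model-parity case can still use the `in::` keys.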
### Fairness controls ### Fairness controls
- **Same pre-processed inputs** — the original processor's `input_ids`, - **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are `pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
fed verbatim to the LeRobot model (no re-tokenization / re-normalization). fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
model comparison isolates the model. LeRobot's own tokenization / image packing is
covered separately by the preprocessor-parity case, which compares its output
against those same collated tensors from identical raw observations.
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The - **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
kernel/rounding noise, not an implementation difference.) kernel/rounding noise, not an implementation difference.)
- **Same flow-matching seed** — fixed (42) right before sampling on both sides. - **Same flow-matching seed** — fixed right before sampling on both sides; the
producer records it in each artifact (`--seed`, default 42) and the consumer
replays the recorded value.
### How to run ### How to run
@@ -90,15 +119,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
``` ```
The `.npz` artifacts are local-only (gitignored, ~69 MB each) and are regenerated by The `.npz` artifacts are local-only (gitignored, ~610 MB each) and are regenerated by
the producer; they are never committed. The test **skips** (does not fail) on CI or the producer; they are never committed. The tests **skip** (do not fail) on CI or
when the checkpoint / artifacts are absent. when the checkpoint / artifacts are absent.
#### Env knobs (all optional) #### Env knobs (all optional)
| Var | Default | Purpose |
| --- | --- | --- |
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
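As a rough illustration of how the tolerance knobs could be consumed, the sketch below reads them from the environment and applies a NumPy-style closeness criterion, `|actual - expected| <= atol + rtol * |expected|`. The criterion and both helper names are assumptions; the README does not show the test's actual comparison call.

```python
import os


def read_tolerances(env=None):
    """Read the tolerance knobs, falling back to the documented 1e-3 default."""
    env = os.environ if env is None else env
    atol = float(env.get("GROOT_PARITY_ATOL", "1e-3"))
    rtol = float(env.get("GROOT_PARITY_RTOL", "1e-3"))
    return atol, rtol


def within_tolerance(actual: float, expected: float, atol: float, rtol: float) -> bool:
    """Assumed NumPy-style criterion; the real test may compare differently."""
    return abs(actual - expected) <= atol + rtol * abs(expected)
```

Under this assumed criterion and the default tolerances, a gap of about 3e-2 (the size the README attributes to flash-attention + bf16 defaults) would fail, which is consistent with the producer forcing fp32 + SDPA.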
@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
@@ -42,6 +43,9 @@ else:
Timesteps = None Timesteps = None
logger = logging.getLogger(__name__)
class TimestepEncoder(nn.Module): class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32): def __init__(self, embedding_dim, compute_dtype=torch.float32):
require_package("diffusers", extra="groot") require_package("diffusers", extra="groot")
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
-print(
-"Total number of DiT parameters: ",
+logger.debug(
+"Total number of DiT parameters: %d",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
for _ in range(self.config.num_layers)
]
)
-print(
-"Total number of SelfAttentionTransformer parameters: ",
+logger.debug(
+"Total number of SelfAttentionTransformer parameters: %d",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
+140 -37
@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import json import json
import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -23,15 +24,37 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.constants import ACTION, OBS_STATE
logger = logging.getLogger(__name__)
GROOT_N1_7 = "n1.7" GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is # Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version # intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise # still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch. # an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5" GROOT_N1_5 = "n1.5"
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
GROOT_N1_5_REMOVAL_GUIDANCE = (
"GR00T N1.5 support was removed from LeRobot. "
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
# falls back to these when a checkpoint ships no image sizing in its processor_config
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
# are resized to the expected resolution instead of being patchified at full camera
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
_GROOT_MODEL_VERSION_ALIASES = { _GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7, "n1.7": GROOT_N1_7,
@@ -41,7 +64,12 @@ _GROOT_MODEL_VERSION_ALIASES = {
"1.7": GROOT_N1_7, "1.7": GROOT_N1_7,
} }
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = { _GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
"none": None, "none": None,
"": None, "": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO, GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
@@ -52,9 +80,10 @@ def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = GROOT_N1_7
-raise ValueError(
-f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
-)
+message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
+if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
+message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
+raise ValueError(message)
return normalized
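The hunk above can be exercised with a standalone sketch. The supported-alias table below contains only the two entries visible in the diff (the rest of the real table is elided by the hunk boundary), and `error_message` is a hypothetical helper added purely to make the behavior easy to inspect.

```python
GROOT_N1_7 = "n1.7"

# Subset of the real alias table: only the entries visible in the diff.
_SUPPORTED_ALIASES = {"n1.7": GROOT_N1_7, "1.7": GROOT_N1_7}
_N1_5_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}

REMOVAL_GUIDANCE = (
    "GR00T N1.5 support was removed from LeRobot. "
    "To keep using an N1.5 checkpoint, pin the last release that supports it: "
    "`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
    "(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)


def normalize_model_version(model_version: str) -> str:
    normalized = _SUPPORTED_ALIASES.get(model_version.lower())
    if normalized is None:
        message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {GROOT_N1_7}."
        if model_version.lower() in _N1_5_ALIASES:
            # Known legacy spelling: append the migration guidance.
            message = f"{message} {REMOVAL_GUIDANCE}"
        raise ValueError(message)
    return normalized


def error_message(model_version: str):
    """Hypothetical helper: return the error text, or None on success."""
    try:
        normalize_model_version(model_version)
    except ValueError as exc:
        return str(exc)
    return None
```

A legacy spelling such as `n1_5` produces the generic error plus the pinning guidance, while an unrelated typo produces only the generic error.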
@@ -286,6 +315,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
def _infer_groot_model_version_from_config(config: dict) -> str | None: def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version") model_version = config.get("model_version")
if isinstance(model_version, str): if isinstance(model_version, str):
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
return GROOT_N1_5
try: try:
return normalize_groot_model_version(model_version) return normalize_groot_model_version(model_version)
except ValueError: except ValueError:
@@ -298,8 +329,14 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_") normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}: if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7 return GROOT_N1_7
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL: if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7 return GROOT_N1_7
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
backbone_cfg = config.get("backbone_cfg")
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
return GROOT_N1_5
return None return None
@@ -310,29 +347,30 @@ class GrootConfig(PreTrainedConfig):
# Basic policy settings # Basic policy settings
n_obs_steps: int = 1 n_obs_steps: int = 1
chunk_size: int = 50 chunk_size: int = 40
n_action_steps: int = 50 n_action_steps: int = 40
# Dimension settings (must match pretrained GR00T model expectations) # Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded. # Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 64 max_state_dim: int = 132
# Maximum action dimension. Shorter actions will be zero-padded. # Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 32 max_action_dim: int = 132
# Normalization (start with identity, adjust as needed) # GR00T normalizes state/action internally in its processor steps (min/max with
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
# handles image normalization. The policy therefore does NOT use LeRobot's
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY, "VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD, "STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.MEAN_STD, "ACTION": NormalizationMode.IDENTITY,
} }
) )
# Image preprocessing (adjust to match Groot's expected input) # Groot-specific model parameters
image_size: tuple[int, int] = (224, 224)
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only. # Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7 model_version: str = GROOT_N1_7
@@ -344,11 +382,47 @@ class GrootConfig(PreTrainedConfig):
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step(). # Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
action_decode_transform: str | None = None # 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1') # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment" embodiment_tag: str = "new_embodiment"
# Inference-only override for the number of flow-matching denoising steps used to decode an
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
# inference speed for action quality; applied at base-model load via _create_groot_model.
num_inference_timesteps: int | None = None
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
execution_horizon: int | None = None
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
# base-model build (so it applies during lerobot-train, which uses __init__ not
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
warm_start_embodiment_slot: int | str | None = None
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
# When True, GR00T converts absolute->relative inside its own pack step (training) and
# reconstructs absolute inside its own flat decode step (inference), using a cached
# reference state. The dataset stays absolute; compute relative ACTION stats with
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
# "['gripper']"` (this only rewrites stats, not actions).
use_relative_actions: bool = False
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
# Case-insensitive token match against action_feature_names.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
# identified and kept absolute. When None, the gripper cannot be identified.
action_feature_names: list[str] | None = None
# Fine-tuning control arguments # Fine-tuning control arguments
# Whether to fine-tune the llm backbone # Whether to fine-tune the llm backbone
@@ -384,17 +458,16 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05 warmup_ratio: float = 0.05
use_bf16: bool = True use_bf16: bool = True
# Dataset parameters # TODO(Steven): Remove these deprecated fields in a future release.
# Video backend to use for training ('decord' or 'torchvision_av') # Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
video_backend: str = "decord" video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t" output_dir: str = "./tmp/gr00t"
save_steps: int = 1000 save_steps: int = 1000
@@ -405,6 +478,12 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False resume: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
)
self.model_version = normalize_groot_model_version(self.model_version) self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform) self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None: if self.base_model_path is None:
@@ -416,26 +495,48 @@ class GrootConfig(PreTrainedConfig):
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success. # 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the # This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon). # action execution horizon (see infer_groot_n1_7_action_execution_horizon).
-        if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
-            self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
+        # Only the 'auto' sentinel resolves to the embodiment default; an explicit
+        # 'none' (normalized to None above) keeps the transform disabled.
+        if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
+            self.action_decode_transform = (
+                GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
+            )
-        if self.max_state_dim == 64:
-            self.max_state_dim = 132
-        if self.max_action_dim == 32:
-            self.max_action_dim = 132
-        if self.chunk_size == 50:
-            self.chunk_size = 40
-        if self.n_action_steps == 50:
-            self.n_action_steps = 40
-        if tuple(self.image_size) == (224, 224):
-            self.image_size = (256, 256)
+        # GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
+        # stale configs) are migrated to the values the N1.7 checkpoints expect, with a
+        # warning. The dataclass defaults are already the N1.7 values, so a plain
+        # GrootConfig() never triggers this.
+        legacy_default_remaps = (
+            ("max_state_dim", 64, 132),
+            ("max_action_dim", 32, 132),
+            ("chunk_size", 50, 40),
+            ("n_action_steps", 50, 40),
+            ("image_size", (224, 224), (256, 256)),
+        )
+        for field_name, legacy_value, n1_7_value in legacy_default_remaps:
+            current_value = getattr(self, field_name)
+            if isinstance(legacy_value, tuple):
+                current_value = tuple(current_value)
+            if current_value == legacy_value:
+                logger.warning(
+                    "GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
+                    "the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
+                    "if this is not what you want.",
+                    field_name,
+                    legacy_value,
+                    n1_7_value,
+                )
+                setattr(self, field_name, n1_7_value)
         inferred_version = infer_groot_model_version(self.base_model_path)
         if inferred_version is not None and inferred_version != self.model_version:
-            raise ValueError(
+            message = (
                 f"GR00T model_version '{self.model_version}' does not match base_model_path "
                 f"'{self.base_model_path}', which looks like '{inferred_version}'."
             )
+            if inferred_version == GROOT_N1_5:
+                message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
+            raise ValueError(message)
         super().__post_init__()
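The table-driven remap in this hunk can be tried in isolation. The sketch below is a minimal stand-in, not the real `GrootConfig`: `DemoConfig` and `LEGACY_DEFAULT_REMAPS` are hypothetical names, and only three of the five fields are reproduced. It shows why the tuple normalization matters: a list loaded from a serialized config would otherwise never compare equal to the legacy tuple.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

# (field name, legacy value, replacement value); a subset of the tuple in the diff.
LEGACY_DEFAULT_REMAPS = (
    ("max_state_dim", 64, 132),
    ("chunk_size", 50, 40),
    ("image_size", (224, 224), (256, 256)),
)


@dataclass
class DemoConfig:  # hypothetical stand-in for GrootConfig
    max_state_dim: int = 132
    chunk_size: int = 40
    image_size: tuple = (256, 256)

    def __post_init__(self):
        for field_name, legacy_value, new_value in LEGACY_DEFAULT_REMAPS:
            current = getattr(self, field_name)
            if isinstance(legacy_value, tuple):
                # A list such as [224, 224] != (224, 224), so normalize before comparing.
                current = tuple(current)
            if current == legacy_value:
                logger.warning(
                    "%s=%s matches a legacy default; remapping it to %s.",
                    field_name,
                    legacy_value,
                    new_value,
                )
                setattr(self, field_name, new_value)


remapped = DemoConfig(chunk_size=50, image_size=[224, 224])
untouched = DemoConfig(chunk_size=16)
```

One consequence of matching on values alone: a user who deliberately passes a legacy value (say `chunk_size=50`) is remapped as well, which is why the warning in the diff tells them to set a different value explicitly.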
@@ -512,7 +613,9 @@ class GrootConfig(PreTrainedConfig):
     @property
     def action_delta_indices(self) -> list[int]:
         """Return indices for delta actions."""
-        model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
+        model_action_horizon = (
+            infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
+        )
         return list(range(min(self.chunk_size, model_action_horizon)))
     @property
+13 -9
@@ -32,6 +32,7 @@ from torch.distributions import Beta
 from lerobot.utils.import_utils import _transformers_available, require_package
 from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
+from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
 if TYPE_CHECKING or _transformers_available:
     from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
@@ -71,13 +72,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
     "backbone_embedding_dim": 2048,
     "tune_llm": False,
     "tune_visual": False,
-    "select_layer": 12,
+    "select_layer": 16,
     "reproject_vision": False,
     "use_flash_attention": True,
     "load_bf16": False,
     "backbone_trainable_params_fp32": True,
-    "image_crop_size": (230, 230),
-    "image_target_size": (256, 256),
+    "image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
+    "image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
     "shortest_image_edge": None,
     "crop_fraction": None,
     "random_rotation_angle": None,
@@ -819,11 +820,14 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
 def get_backbone_cls(config: GR00TN17Config):
-    if (
-        config.backbone_model_type == "qwen"
-        or "nvidia/Cosmos-Reason2" in config.model_name
-        or "Qwen/Qwen3-VL" in config.model_name
-    ):
+    if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
+        return Qwen3Backbone
+    if config.backbone_model_type == "qwen":
+        logger.warning(
+            "Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
+            "backbone because backbone_model_type='qwen'.",
+            config.model_name,
+        )
         return Qwen3Backbone
     raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
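The new resolution order (recognized name first, then a warned fallback on `backbone_model_type`, then an error) can be sketched without the model code. Everything below is illustrative: `DemoBackboneConfig` and `resolve_backbone_name` are hypothetical, the function returns a class name as a string rather than the class, and the model names in the usage lines are made up.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

# Substrings the diff treats as recognized backbone names.
KNOWN_NAME_MARKERS = ("nvidia/Cosmos-Reason2", "Qwen/Qwen3-VL")


@dataclass
class DemoBackboneConfig:  # hypothetical stand-in for GR00TN17Config
    model_name: str
    backbone_model_type: str = ""


def resolve_backbone_name(config: DemoBackboneConfig) -> str:
    # 1. A recognized model name resolves silently.
    if any(marker in config.model_name for marker in KNOWN_NAME_MARKERS):
        return "Qwen3Backbone"
    # 2. An unrecognized name with backbone_model_type='qwen' still resolves, but loudly.
    if config.backbone_model_type == "qwen":
        logger.warning(
            "Unrecognized backbone model name '%s'; assuming a Qwen3-VL-compatible backbone.",
            config.model_name,
        )
        return "Qwen3Backbone"
    # 3. Anything else is rejected.
    raise ValueError(f"Unsupported backbone model: {config.model_name}")


recognized = resolve_backbone_name(DemoBackboneConfig("Qwen/Qwen3-VL-demo"))
fallback = resolve_backbone_name(DemoBackboneConfig("acme/custom-vlm", backbone_model_type="qwen"))
```

The behaviour for valid configs is unchanged; the split only makes the previously silent fallback visible in the logs.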
@@ -909,7 +913,7 @@ class GR00TN17(PreTrainedModel):
             "trust_remote_code": True
         }
         load_backbone_weights = kwargs.pop("load_backbone_weights", False)
-        for key in ("revision", "cache_dir", "local_files_only", "token"):
+        for key in ("cache_dir", "local_files_only", "token"):
             if key in kwargs:
                 transformers_loading_kwargs.setdefault(key, kwargs[key])
+166 -50
@@ -18,15 +18,12 @@
 Groot Policy Wrapper for LeRobot Integration
 Minimal integration that delegates to Isaac-GR00T N1.7 components where
-possible without porting their code.
-Notes:
-- Dataset loading and full training orchestration is handled by Isaac-GR00T
-  TrainRunner in their codebase. If you want to invoke that flow end-to-end
-  from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
+possible without porting their code. Dataset loading and training
+orchestration are handled by LeRobot's standard training stack.
 """
 import builtins
+import logging
 import os
 from collections import deque
 from pathlib import Path
@@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package
 from ..pretrained import PreTrainedPolicy
 from ..utils import get_device_from_parameters
 from .configuration_groot import (
+    GROOT_N1_5,
+    GROOT_N1_5_REMOVAL_GUIDANCE,
     GROOT_N1_7,
     GrootConfig,
     infer_groot_model_version,
@@ -50,9 +49,103 @@ from .configuration_groot import (
     normalize_groot_model_version,
 )
+logger = logging.getLogger(__name__)
 T = TypeVar("T", bound="GrootPolicy")
+def _resolve_embodiment_id(value: int | str) -> int:
+    """Resolve an embodiment id from an int or an N1.7 embodiment name.
+    Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
+    Raises ValueError listing the known keys if the name is unknown.
+    """
+    from .processor_groot import N1_7_EMBODIMENT_MAPPING
+    if isinstance(value, bool):  # bool is a subclass of int; reject it explicitly.
+        raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
+    if isinstance(value, int):
+        return value
+    if value in N1_7_EMBODIMENT_MAPPING:
+        return N1_7_EMBODIMENT_MAPPING[value]
+    raise ValueError(
+        f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
+        f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
+    )
+def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
+    """Copy category-specific action-head weights from one embodiment slot to another.
+    Used at base-model load (training only) to warm-start a cold target embodiment slot
+    (e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
+    parameters across every CategorySpecificLinear in the action head's state encoder,
+    action encoder, and action decoder. No-ops (with a logged warning) if the ids are out
+    of range or identical.
+    """
+    if source_id == target_id:
+        logger.warning(
+            "GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
+            "skipping (nothing to copy).",
+            source_id,
+        )
+        return
+    action_head = getattr(model, "action_head", None)
+    if action_head is None:
+        logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
+        return
+    # Each entry is (submodule, [CategorySpecificLinear attribute names]).
+    linear_groups = [
+        (getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
+        (getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
+        (getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
+    ]
+    copied: list[str] = []
+    with torch.no_grad():
+        for submodule, attr_names in linear_groups:
+            if submodule is None:
+                continue
+            submodule_name = type(submodule).__name__
+            for attr_name in attr_names:
+                lin = getattr(submodule, attr_name, None)
+                if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
+                    continue
+                num_categories = lin.W.shape[0]
+                if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
+                    logger.warning(
+                        "GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
+                        "for %s.%s (num_categories=%d); skipping this layer.",
+                        source_id,
+                        target_id,
+                        submodule_name,
+                        attr_name,
+                        num_categories,
+                    )
+                    continue
+                lin.W.data[target_id] = lin.W.data[source_id].clone()
+                lin.b.data[target_id] = lin.b.data[source_id].clone()
+                copied.append(f"{submodule_name}.{attr_name}")
+    if copied:
+        logger.info(
+            "GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
+            "to slot %d for: %s.",
+            source_id,
+            target_id,
+            ", ".join(copied),
+        )
+    else:
+        logger.warning(
+            "GR00T warm_start_embodiment_slot: no action-head weights were copied "
+            "(source_id=%d, target_id=%d).",
+            source_id,
+            target_id,
+        )
 class GrootPolicy(PreTrainedPolicy):
     """Wrapper around external Groot model for LeRobot integration."""
@@ -92,8 +185,24 @@ class GrootPolicy(PreTrainedPolicy):
             transformers_loading_kwargs={"trust_remote_code": True},
         )
-        model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
-        model.config.compute_dtype = model.compute_dtype
+        # Inference-only override for the number of flow-matching denoising steps. The action
+        # head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
+        # t schedule adapt automatically.
+        if self.config.num_inference_timesteps is not None:
+            n = int(self.config.num_inference_timesteps)
+            model.config.num_inference_timesteps = n
+            model.action_head.num_inference_timesteps = n
+        # Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
+        # slot's action-head weights. Done here (not in from_pretrained) so it applies on every
+        # fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
+        # directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
+        # through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
+        # immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
+        if self.config.warm_start_embodiment_slot is not None:
+            source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
+            target_id = _resolve_embodiment_id(self.config.embodiment_tag)
+            _warm_start_embodiment_slot(model, source_id, target_id)
         return model
@@ -148,9 +257,10 @@ class GrootPolicy(PreTrainedPolicy):
             if config is not None
             else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
         )
-        print(
-            f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
-            f"Loading pretrained model from: {pretrained_name_or_path}"
+        logger.info(
+            "The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
+            requested_version,
+            pretrained_name_or_path,
         )
         model_id = str(pretrained_name_or_path)
@@ -181,7 +291,7 @@ class GrootPolicy(PreTrainedPolicy):
         if is_finetuned_checkpoint:
             # This is a fine-tuned LeRobot checkpoint - use parent class loading
-            print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
+            logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
             return super().from_pretrained(
                 pretrained_name_or_path=pretrained_name_or_path,
                 config=config,
@@ -197,7 +307,7 @@ class GrootPolicy(PreTrainedPolicy):
             )
         # This is a base GR00T model - load it fresh
-        print("Detected base GR00T model, loading from HuggingFace...")
+        logger.info("Detected base GR00T model, loading from HuggingFace...")
         if config is None:
             model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
@@ -229,10 +339,13 @@ class GrootPolicy(PreTrainedPolicy):
         config.model_version = normalize_groot_model_version(config.model_version)
         inferred_version = infer_groot_model_version(config.base_model_path)
         if inferred_version is not None and inferred_version != config.model_version:
-            raise ValueError(
+            message = (
                 f"GR00T model_version '{config.model_version}' does not match base_model_path "
                 f"'{config.base_model_path}', which looks like '{inferred_version}'."
             )
+            if inferred_version == GROOT_N1_5:
+                message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
+            raise ValueError(message)
         # Create a fresh policy instance - this will automatically load the GR00T model
         # in __init__ via _create_groot_model()
         policy = cls(config)
@@ -258,7 +371,11 @@ class GrootPolicy(PreTrainedPolicy):
             horizons.append(checkpoint_action_horizon)
         if execution_horizon is not None:
             horizons.append(execution_horizon)
-        return min(horizons)
+        # An explicit config override caps the open-loop horizon (inference cadence), overriding
+        # the value inferred from the checkpoint/embodiment.
+        if self.config.execution_horizon is not None:
+            horizons.append(max(1, int(self.config.execution_horizon)))
+        return max(1, min(horizons))
     def _resolve_prediction_horizon(self, actions: Tensor) -> int:
         """Return the policy-facing action horizon for a native GR00T prediction."""
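The clamping added to the execution-horizon resolution can be checked with plain integers. This is a sketch under assumptions: `resolve_execution_horizon` is a hypothetical free function, and the `candidates` argument stands in for whatever values the method has already appended to `horizons` before the lines shown in this hunk.

```python
def resolve_execution_horizon(candidates, override=None):
    """Return the smallest available horizon, never below 1.

    `candidates` holds the horizons collected so far (None entries are skipped);
    `override` plays the role of config.execution_horizon.
    """
    horizons = [h for h in candidates if h is not None]
    if override is not None:
        # The override joins the candidate list, so it can only lower the result.
        horizons.append(max(1, int(override)))
    return max(1, min(horizons))


inferred_only = resolve_execution_horizon([40, 16])
capped = resolve_execution_horizon([40, 16], override=8)
too_large = resolve_execution_horizon([40, 16], override=100)
floored = resolve_execution_horizon([40, 16], override=0)
```

Because the override is combined with `min`, it acts as a cap rather than a replacement: a value above the inferred horizon has no effect, and a value of zero or less is raised to 1.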
@@ -297,9 +414,7 @@ class GrootPolicy(PreTrainedPolicy):
             allowed_base.add("action_mask")
         return {
-            k: v
-            for k, v in batch.items()
-            if k in allowed_base and not (k.startswith("next.") or k == "info")
+            k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
         }
     def _prepare_n1_7_rtc_inputs(
@@ -320,9 +435,7 @@ class GrootPolicy(PreTrainedPolicy):
         if prev_actions.ndim == 2:
             prev_actions = prev_actions.unsqueeze(0)
         elif prev_actions.ndim != 3:
-            raise ValueError(
-                "prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
-            )
+            raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
         state = inputs.get("state")
         if state is None:
         if prev_actions.shape[0] == 1 and batch_size > 1:
             prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
         elif prev_actions.shape[0] != batch_size:
-            raise ValueError(
-                "prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
-            )
+            raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
         # The generic LeRobot RTC engine pads short leftovers with exact zero
         # rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
@@ -346,7 +457,9 @@ class GrootPolicy(PreTrainedPolicy):
         else:
             return inputs, None
-        model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
+        model_action_horizon = int(
+            getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
+        )
         max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
         if prev_actions.shape[1] > model_action_horizon:
             prev_actions = prev_actions[:, -model_action_horizon:, :]
@@ -409,6 +522,11 @@ class GrootPolicy(PreTrainedPolicy):
         # Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
         loss = outputs.get("loss")
+        if loss is None:
+            raise RuntimeError(
+                "GR00T model.forward did not return a 'loss'. Training batches must include "
+                "'action' and 'action_mask'; check the preprocessor output."
+            )
         loss_dict = {"loss": loss.item()}
@@ -425,6 +543,16 @@ class GrootPolicy(PreTrainedPolicy):
         """
         self.eval()
+        # Freeze the relative-action reference at the exact chunk-prediction event so every popped
+        # delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
+        # per-tick latest state. Driven by the predict event, so it is correct under any runtime
+        # n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
+        from .processor_groot import _GROOT_REF_HOLDER_KEY
+        holder = batch.get(_GROOT_REF_HOLDER_KEY)
+        if holder is not None:
+            holder.freeze()
         # Preprocessing is handled by the processor pipeline, so we just filter the batch.
         # During inference, we do not pass action because it is predicted.
         # N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
@@ -471,33 +599,21 @@ class GrootPolicy(PreTrainedPolicy):
     # Internal helpers
     # -------------------------
     def _handle_flash_attention_compatibility(self) -> None:
-        """Handle Flash Attention compatibility issues by setting environment variables.
-        This addresses the common 'undefined symbol' error that occurs when Flash Attention
-        is compiled against a different PyTorch version than what's currently installed.
+        """Log Flash Attention availability (diagnostic only).
+        The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
+        unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
+        change behaviour or mutate global state.
         """
-        # Set environment variables to handle Flash Attention compatibility
-        # These help with symbol resolution issues
-        os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
-        os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
-        # Try to import flash_attn and handle failures gracefully
         try:
             import flash_attn
-            print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
-        except ImportError as e:
-            print(f"[GROOT] Flash Attention not available: {e}")
-            print("[GROOT] Will use fallback attention mechanism")
-        except Exception as e:
-            if "undefined symbol" in str(e):
-                print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
-                print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
-                print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
-                print(" pip uninstall flash-attn")
-                print(" pip install --no-build-isolation flash-attn==2.6.3")
-                print("[GROOT] Continuing with fallback attention mechanism")
-            else:
-                print(f"[GROOT] Flash Attention error: {e}")
-                print("[GROOT] Continuing with fallback attention mechanism")
+            logger.debug("Flash Attention %s is available.", flash_attn.__version__)
+        except ImportError:
+            logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
+        except Exception as e:  # noqa: BLE001
+            logger.warning(
+                "Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
+                "an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
+                e,
+            )
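The rewritten probe follows a common pattern for optional dependencies: try the import, log the outcome, and leave global state alone. The sketch below generalizes it to any module name using the standard library; `probe_optional_module` and its string return values are hypothetical and exist only so the outcome can be inspected.

```python
import importlib
import logging

logger = logging.getLogger(__name__)


def probe_optional_module(name: str) -> str:
    """Report whether an optional module imports cleanly, without side effects."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        # Covers ModuleNotFoundError as well, since it subclasses ImportError.
        logger.debug("%s is not installed; the caller should use its fallback path.", name)
        return "missing"
    except Exception as exc:  # any other failure raised while the module initializes
        logger.warning("%s failed to import (%s); the caller should use its fallback path.", name, exc)
        return "broken"
    logger.debug("%s %s is available.", name, getattr(module, "__version__", "unknown"))
    return "available"


stdlib_status = probe_optional_module("json")
absent_status = probe_optional_module("surely_not_an_installed_module_xyz")
```

Routing the messages through `logging` at `debug` level, as the diff does, keeps normal runs quiet while still leaving a trace when someone needs to diagnose which attention path was taken.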
File diff suppressed because it is too large
@@ -207,11 +207,6 @@ def test_lerobot_groot_forward_pass():
     with torch.no_grad():
         lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
-    assert isinstance(lerobot_loss, torch.Tensor)
-    assert torch.isfinite(lerobot_loss).all()
-    assert "loss" in lerobot_metrics
-    assert np.isfinite(lerobot_metrics["loss"])
     print("\nForward pass successful.")
     print(f" - Loss: {lerobot_loss.item():.6f}")
     print(f" - Metrics: {lerobot_metrics}")
File diff suppressed because it is too large