Compare commits

..

1 Commits

Author SHA1 Message Date
Steven Palma bf9877fa0b test(groot): regression coverage and CI guards for the N1.7 review fixes
Adds/updates unit tests for the N1.5 removal surfaces (config, checkpoint
markers, removed processor steps, v2 action-unpack registration), the
legacy-default remap warnings, action_decode_transform auto/none resolution,
2-D action decoding, the per-instance raw-state cache and pack/decode
reconnection, raw-checkpoint stats fallback and override handling, camera-match
warnings, bf16 handling, and backbone loading kwargs. Adds pytest.importorskip
guards so the fast_tests tiers pass without transformers, and asserts the
training forward pass returns a finite loss.

Note: these tests exercise symbols introduced by the GR00T N1.7 source PRs
(source-core, backbone); merge those for green CI on this branch.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-12 23:38:08 +02:00
9 changed files with 1154 additions and 1094 deletions
+1 -4
View File
@@ -4,9 +4,6 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
LeRobot integrates GR00T N1.7 through the `groot` policy type. LeRobot integrates GR00T N1.7 through the `groot` policy type.
> [!WARNING]
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
## Model Overview ## Model Overview
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO. GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
@@ -136,7 +133,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field. Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash ```bash
hf download nvidia/GR00T-N1.7-LIBERO \ huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \ --include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO --local-dir ./GR00T-N1.7-LIBERO
+23 -52
View File
@@ -1,13 +1,6 @@
## Research Paper ## Research Paper
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734 Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
> Current releases support GR00T N1.7 only.
## Repository ## Repository
@@ -38,22 +31,12 @@ Hugging Face Models:
## Original-vs-LeRobot parity test ## Original-vs-LeRobot parity test
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot `tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head) reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
against NVIDIA's original `gr00t` package with two comparisons, each parametrized produces the **same raw model output** (`get_action(...)["action_pred"]`, the
over every embodiment tag present in the checkpoint: normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
byte-identical pre-processed inputs and the same flow-matching seed. It is
1. **Model parity** — given byte-identical pre-processed inputs and the same parametrized over every embodiment tag present in the checkpoint.
flow-matching seed (recorded in each artifact), both implementations must produce
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
flow-matching prediction). Output shapes must match exactly; any action-horizon
or action-dim mismatch fails the test.
2. **Preprocessor parity** — given the identical raw observations (per-camera
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
state normalization, no mocks) must produce the **same collated model inputs**
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
`embodiment_id`) as the original package's processor.
### Why two environments ### Why two environments
@@ -65,37 +48,25 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
So the test uses a **producer / consumer** split across two venvs: So the test uses a **producer / consumer** split across two venvs:
1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_ 1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
gr00t venv. For each embodiment it builds dummy inputs generically from the gr00t venv. For each embodiment it builds dummy inputs generically from the
checkpoint metadata (state dims from `statistics.json`; camera/language keys from checkpoint metadata (state dims from `statistics.json`; camera/language keys from
the processor modality configs), runs the original model, and saves to one `.npz` the processor modality configs), runs the original model, and saves the exact
per tag: the raw observations (`raw::` keys), the exact collated inputs collated inputs + raw `action_pred` to one `.npz` per tag.
(`in::` keys), the seed, and the raw `action_pred`. 2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
2. **Consumer** the pytest above, run in the _LeRobot_ venv. It discovers every `.npz`, replays the byte-identical inputs through the LeRobot model with the same
`.npz`; the model-parity case replays the byte-identical collated inputs through seed, and asserts the outputs match.
the LeRobot model with the recorded seed and asserts the outputs match, and the
preprocessor-parity case replays the raw observations through LeRobot's full
preprocessor pipeline and asserts the collated tensors match.
> Artifacts generated by older versions of the dump script contain no `raw::`
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
> Re-run the producer to refresh them.
### Fairness controls ### Fairness controls
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`, - **Same pre-processed inputs** — the original processor's `input_ids`,
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are `pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
model comparison isolates the model. LeRobot's own tokenization / image packing is
covered separately by the preprocessor-parity case, which compares its output
against those same collated tensors from identical raw observations.
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The - **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
kernel/rounding noise, not an implementation difference.) kernel/rounding noise, not an implementation difference.)
- **Same flow-matching seed** — fixed right before sampling on both sides; the - **Same flow-matching seed** — fixed (42) right before sampling on both sides.
producer records it in each artifact (`--seed`, default 42) and the consumer
replays the recorded value.
### How to run ### How to run
@@ -119,15 +90,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
``` ```
The `.npz` artifacts are local-only (gitignored, ~610 MB each) and are regenerated by The `.npz` artifacts are local-only (gitignored, ~69 MB each) and are regenerated by
the producer; they are never committed. The tests **skip** (do not fail) on CI or the producer; they are never committed. The test **skips** (does not fail) on CI or
when the checkpoint / artifacts are absent. when the checkpoint / artifacts are absent.
#### Env knobs (all optional) #### Env knobs (all optional)
| Var | Default | Purpose | | Var | Default | Purpose |
| ----------------------------------------- | -------------------------------- | ------------------------------------- | |---|---|---|
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts | | `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir | | `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` | | `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance | | `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
@@ -43,9 +42,6 @@ else:
Timesteps = None Timesteps = None
logger = logging.getLogger(__name__)
class TimestepEncoder(nn.Module): class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32): def __init__(self, embedding_dim, compute_dtype=torch.float32):
require_package("diffusers", extra="groot") require_package("diffusers", extra="groot")
@@ -269,8 +265,8 @@ class DiT(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim) self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
logger.debug( print(
"Total number of DiT parameters: %d", "Total number of DiT parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad), sum(p.numel() for p in self.parameters() if p.requires_grad),
) )
@@ -430,8 +426,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
for _ in range(self.config.num_layers) for _ in range(self.config.num_layers)
] ]
) )
logger.debug( print(
"Total number of SelfAttentionTransformer parameters: %d", "Total number of SelfAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad), sum(p.numel() for p in self.parameters() if p.requires_grad),
) )
+37 -140
View File
@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
import json import json
import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -24,37 +23,15 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.constants import ACTION, OBS_STATE
logger = logging.getLogger(__name__)
GROOT_N1_7 = "n1.7" GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is # Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version # intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise # still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch. # an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5" GROOT_N1_5 = "n1.5"
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
GROOT_N1_5_REMOVAL_GUIDANCE = (
"GR00T N1.5 support was removed from LeRobot. "
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B" GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B" GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
# falls back to these when a checkpoint ships no image sizing in its processor_config
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
# are resized to the expected resolution instead of being patchified at full camera
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero" GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
_GROOT_MODEL_VERSION_ALIASES = { _GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7, "n1.7": GROOT_N1_7,
@@ -64,12 +41,7 @@ _GROOT_MODEL_VERSION_ALIASES = {
"1.7": GROOT_N1_7, "1.7": GROOT_N1_7,
} }
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = { _GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
"none": None, "none": None,
"": None, "": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO, GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
@@ -80,10 +52,9 @@ def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower()) normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None: if normalized is None:
supported = GROOT_N1_7 supported = GROOT_N1_7
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}." raise ValueError(
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES: f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}" )
raise ValueError(message)
return normalized return normalized
@@ -315,8 +286,6 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
def _infer_groot_model_version_from_config(config: dict) -> str | None: def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version") model_version = config.get("model_version")
if isinstance(model_version, str): if isinstance(model_version, str):
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
return GROOT_N1_5
try: try:
return normalize_groot_model_version(model_version) return normalize_groot_model_version(model_version)
except ValueError: except ValueError:
@@ -329,14 +298,8 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_") normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}: if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7 return GROOT_N1_7
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL: if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7 return GROOT_N1_7
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
backbone_cfg = config.get("backbone_cfg")
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
return GROOT_N1_5
return None return None
@@ -347,30 +310,29 @@ class GrootConfig(PreTrainedConfig):
# Basic policy settings # Basic policy settings
n_obs_steps: int = 1 n_obs_steps: int = 1
chunk_size: int = 40 chunk_size: int = 50
n_action_steps: int = 40 n_action_steps: int = 50
# Dimension settings (must match pretrained GR00T model expectations) # Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded. # Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 132 max_state_dim: int = 64
# Maximum action dimension. Shorter actions will be zero-padded. # Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 132 max_action_dim: int = 32
# GR00T normalizes state/action internally in its processor steps (min/max with # Normalization (start with identity, adjust as needed)
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
# handles image normalization. The policy therefore does NOT use LeRobot's
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY, "VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY, "STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.IDENTITY, "ACTION": NormalizationMode.MEAN_STD,
} }
) )
# Groot-specific model parameters # Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only. # Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7 model_version: str = GROOT_N1_7
@@ -382,47 +344,11 @@ class GrootConfig(PreTrainedConfig):
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step(). # Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no action_decode_transform: str | None = None
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1') # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment" embodiment_tag: str = "new_embodiment"
# Inference-only override for the number of flow-matching denoising steps used to decode an
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
# inference speed for action quality; applied at base-model load via _create_groot_model.
num_inference_timesteps: int | None = None
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
execution_horizon: int | None = None
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
# base-model build (so it applies during lerobot-train, which uses __init__ not
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
warm_start_embodiment_slot: int | str | None = None
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
# When True, GR00T converts absolute->relative inside its own pack step (training) and
# reconstructs absolute inside its own flat decode step (inference), using a cached
# reference state. The dataset stays absolute; compute relative ACTION stats with
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
# "['gripper']"` (this only rewrites stats, not actions).
use_relative_actions: bool = False
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
# Case-insensitive token match against action_feature_names.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
# identified and kept absolute. When None, the gripper cannot be identified.
action_feature_names: list[str] | None = None
# Fine-tuning control arguments # Fine-tuning control arguments
# Whether to fine-tune the llm backbone # Whether to fine-tune the llm backbone
@@ -458,16 +384,17 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05 warmup_ratio: float = 0.05
use_bf16: bool = True use_bf16: bool = True
# TODO(Steven): Remove these deprecated fields in a future release. # Dataset parameters
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation # Video backend to use for training ('decord' or 'torchvision_av')
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
video_backend: str = "decord" video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t" output_dir: str = "./tmp/gr00t"
save_steps: int = 1000 save_steps: int = 1000
@@ -478,12 +405,6 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False resume: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
)
self.model_version = normalize_groot_model_version(self.model_version) self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform) self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None: if self.base_model_path is None:
@@ -495,48 +416,26 @@ class GrootConfig(PreTrainedConfig):
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success. # 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the # This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon). # action execution horizon (see infer_groot_n1_7_action_execution_horizon).
# Only the 'auto' sentinel resolves to the embodiment default; an explicit if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
# 'none' (normalized to None above) keeps the transform disabled. self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
self.action_decode_transform = (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
)
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or if self.max_state_dim == 64:
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a self.max_state_dim = 132
# warning. The dataclass defaults are already the N1.7 values, so a plain if self.max_action_dim == 32:
# GrootConfig() never triggers this. self.max_action_dim = 132
legacy_default_remaps = ( if self.chunk_size == 50:
("max_state_dim", 64, 132), self.chunk_size = 40
("max_action_dim", 32, 132), if self.n_action_steps == 50:
("chunk_size", 50, 40), self.n_action_steps = 40
("n_action_steps", 50, 40), if tuple(self.image_size) == (224, 224):
("image_size", (224, 224), (256, 256)), self.image_size = (256, 256)
)
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
current_value = getattr(self, field_name)
if isinstance(legacy_value, tuple):
current_value = tuple(current_value)
if current_value == legacy_value:
logger.warning(
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
"if this is not what you want.",
field_name,
legacy_value,
n1_7_value,
)
setattr(self, field_name, n1_7_value)
inferred_version = infer_groot_model_version(self.base_model_path) inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version: if inferred_version is not None and inferred_version != self.model_version:
message = ( raise ValueError(
f"GR00T model_version '{self.model_version}' does not match base_model_path " f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'." f"'{self.base_model_path}', which looks like '{inferred_version}'."
) )
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
super().__post_init__() super().__post_init__()
@@ -613,9 +512,7 @@ class GrootConfig(PreTrainedConfig):
@property @property
def action_delta_indices(self) -> list[int]: def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions.""" """Return indices for delta actions."""
model_action_horizon = ( model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
return list(range(min(self.chunk_size, model_action_horizon))) return list(range(min(self.chunk_size, model_action_horizon)))
@property @property
+9 -13
View File
@@ -32,7 +32,6 @@ from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
if TYPE_CHECKING or _transformers_available: if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
@@ -72,13 +71,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"backbone_embedding_dim": 2048, "backbone_embedding_dim": 2048,
"tune_llm": False, "tune_llm": False,
"tune_visual": False, "tune_visual": False,
"select_layer": 16, "select_layer": 12,
"reproject_vision": False, "reproject_vision": False,
"use_flash_attention": True, "use_flash_attention": True,
"load_bf16": False, "load_bf16": False,
"backbone_trainable_params_fp32": True, "backbone_trainable_params_fp32": True,
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE, "image_crop_size": (230, 230),
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE, "image_target_size": (256, 256),
"shortest_image_edge": None, "shortest_image_edge": None,
"crop_fraction": None, "crop_fraction": None,
"random_rotation_angle": None, "random_rotation_angle": None,
@@ -820,14 +819,11 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
def get_backbone_cls(config: GR00TN17Config): def get_backbone_cls(config: GR00TN17Config):
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name: if (
return Qwen3Backbone config.backbone_model_type == "qwen"
if config.backbone_model_type == "qwen": or "nvidia/Cosmos-Reason2" in config.model_name
logger.warning( or "Qwen/Qwen3-VL" in config.model_name
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible " ):
"backbone because backbone_model_type='qwen'.",
config.model_name,
)
return Qwen3Backbone return Qwen3Backbone
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}") raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
@@ -913,7 +909,7 @@ class GR00TN17(PreTrainedModel):
"trust_remote_code": True "trust_remote_code": True
} }
load_backbone_weights = kwargs.pop("load_backbone_weights", False) load_backbone_weights = kwargs.pop("load_backbone_weights", False)
for key in ("cache_dir", "local_files_only", "token"): for key in ("revision", "cache_dir", "local_files_only", "token"):
if key in kwargs: if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key]) transformers_loading_kwargs.setdefault(key, kwargs[key])
+50 -166
View File
@@ -18,12 +18,15 @@
Groot Policy Wrapper for LeRobot Integration Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T N1.7 components where Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code. Dataset loading and training possible without porting their code.
orchestration are handled by LeRobot's standard training stack.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
""" """
import builtins import builtins
import logging
import os import os
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
@@ -39,8 +42,6 @@ from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters from ..utils import get_device_from_parameters
from .configuration_groot import ( from .configuration_groot import (
GROOT_N1_5,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7, GROOT_N1_7,
GrootConfig, GrootConfig,
infer_groot_model_version, infer_groot_model_version,
@@ -49,103 +50,9 @@ from .configuration_groot import (
normalize_groot_model_version, normalize_groot_model_version,
) )
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy") T = TypeVar("T", bound="GrootPolicy")
def _resolve_embodiment_id(value: int | str) -> int:
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
Raises ValueError listing the known keys if the name is unknown.
"""
from .processor_groot import N1_7_EMBODIMENT_MAPPING
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
if isinstance(value, int):
return value
if value in N1_7_EMBODIMENT_MAPPING:
return N1_7_EMBODIMENT_MAPPING[value]
raise ValueError(
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
)
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
"""Copy category-specific action-head weights from one embodiment slot to another.
Used at base-model load (training only) to warm-start a cold target embodiment slot
(e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
parameters across every CategorySpecificLinear in the action head's state encoder,
action encoder, and action decoder. No-ops (with a logged warning) if the ids are out
of range or identical.
"""
if source_id == target_id:
logger.warning(
"GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
"skipping (nothing to copy).",
source_id,
)
return
action_head = getattr(model, "action_head", None)
if action_head is None:
logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
return
# Each entry is (submodule, [CategorySpecificLinear attribute names]).
linear_groups = [
(getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
(getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
(getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
]
copied: list[str] = []
with torch.no_grad():
for submodule, attr_names in linear_groups:
if submodule is None:
continue
submodule_name = type(submodule).__name__
for attr_name in attr_names:
lin = getattr(submodule, attr_name, None)
if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
continue
num_categories = lin.W.shape[0]
if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
logger.warning(
"GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
"for %s.%s (num_categories=%d); skipping this layer.",
source_id,
target_id,
submodule_name,
attr_name,
num_categories,
)
continue
lin.W.data[target_id] = lin.W.data[source_id].clone()
lin.b.data[target_id] = lin.b.data[source_id].clone()
copied.append(f"{submodule_name}.{attr_name}")
if copied:
logger.info(
"GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
"to slot %d for: %s.",
source_id,
target_id,
", ".join(copied),
)
else:
logger.warning(
"GR00T warm_start_embodiment_slot: no action-head weights were copied "
"(source_id=%d, target_id=%d).",
source_id,
target_id,
)
class GrootPolicy(PreTrainedPolicy): class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration.""" """Wrapper around external Groot model for LeRobot integration."""
@@ -185,24 +92,8 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True}, transformers_loading_kwargs={"trust_remote_code": True},
) )
# Inference-only override for the number of flow-matching denoising steps. The action model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the model.config.compute_dtype = model.compute_dtype
# t schedule adapt automatically.
if self.config.num_inference_timesteps is not None:
n = int(self.config.num_inference_timesteps)
model.config.num_inference_timesteps = n
model.action_head.num_inference_timesteps = n
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
if self.config.warm_start_embodiment_slot is not None:
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
_warm_start_embodiment_slot(model, source_id, target_id)
return model return model
@@ -257,10 +148,9 @@ class GrootPolicy(PreTrainedPolicy):
if config is not None if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7 else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
) )
logger.info( print(
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s", f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
requested_version, f"Loading pretrained model from: {pretrained_name_or_path}"
pretrained_name_or_path,
) )
model_id = str(pretrained_name_or_path) model_id = str(pretrained_name_or_path)
@@ -291,7 +181,7 @@ class GrootPolicy(PreTrainedPolicy):
if is_finetuned_checkpoint: if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading # This is a fine-tuned LeRobot checkpoint - use parent class loading
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...") print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained( return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path, pretrained_name_or_path=pretrained_name_or_path,
config=config, config=config,
@@ -307,7 +197,7 @@ class GrootPolicy(PreTrainedPolicy):
) )
# This is a base GR00T model - load it fresh # This is a base GR00T model - load it fresh
logger.info("Detected base GR00T model, loading from HuggingFace...") print("Detected base GR00T model, loading from HuggingFace...")
if config is None: if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7 model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
@@ -339,13 +229,10 @@ class GrootPolicy(PreTrainedPolicy):
config.model_version = normalize_groot_model_version(config.model_version) config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path) inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version: if inferred_version is not None and inferred_version != config.model_version:
message = ( raise ValueError(
f"GR00T model_version '{config.model_version}' does not match base_model_path " f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'." f"'{config.base_model_path}', which looks like '{inferred_version}'."
) )
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
# Create a fresh policy instance - this will automatically load the GR00T model # Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model() # in __init__ via _create_groot_model()
policy = cls(config) policy = cls(config)
@@ -371,11 +258,7 @@ class GrootPolicy(PreTrainedPolicy):
horizons.append(checkpoint_action_horizon) horizons.append(checkpoint_action_horizon)
if execution_horizon is not None: if execution_horizon is not None:
horizons.append(execution_horizon) horizons.append(execution_horizon)
# An explicit config override caps the open-loop horizon (inference cadence), overriding return min(horizons)
# the value inferred from the checkpoint/embodiment.
if self.config.execution_horizon is not None:
horizons.append(max(1, int(self.config.execution_horizon)))
return max(1, min(horizons))
def _resolve_prediction_horizon(self, actions: Tensor) -> int: def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction.""" """Return the policy-facing action horizon for a native GR00T prediction."""
@@ -414,7 +297,9 @@ class GrootPolicy(PreTrainedPolicy):
allowed_base.add("action_mask") allowed_base.add("action_mask")
return { return {
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info") k: v
for k, v in batch.items()
if k in allowed_base and not (k.startswith("next.") or k == "info")
} }
def _prepare_n1_7_rtc_inputs( def _prepare_n1_7_rtc_inputs(
@@ -435,7 +320,9 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.ndim == 2: if prev_actions.ndim == 2:
prev_actions = prev_actions.unsqueeze(0) prev_actions = prev_actions.unsqueeze(0)
elif prev_actions.ndim != 3: elif prev_actions.ndim != 3:
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.") raise ValueError(
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
)
state = inputs.get("state") state = inputs.get("state")
if state is None: if state is None:
@@ -444,7 +331,9 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.shape[0] == 1 and batch_size > 1: if prev_actions.shape[0] == 1 and batch_size > 1:
prev_actions = prev_actions.expand(batch_size, -1, -1).clone() prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
elif prev_actions.shape[0] != batch_size: elif prev_actions.shape[0] != batch_size:
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.") raise ValueError(
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
)
# The generic LeRobot RTC engine pads short leftovers with exact zero # The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every # rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
@@ -457,9 +346,7 @@ class GrootPolicy(PreTrainedPolicy):
else: else:
return inputs, None return inputs, None
model_action_horizon = int( model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
)
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim)) max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon: if prev_actions.shape[1] > model_action_horizon:
prev_actions = prev_actions[:, -model_action_horizon:, :] prev_actions = prev_actions[:, -model_action_horizon:, :]
@@ -522,11 +409,6 @@ class GrootPolicy(PreTrainedPolicy):
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss' # Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss") loss = outputs.get("loss")
if loss is None:
raise RuntimeError(
"GR00T model.forward did not return a 'loss'. Training batches must include "
"'action' and 'action_mask'; check the preprocessor output."
)
loss_dict = {"loss": loss.item()} loss_dict = {"loss": loss.item()}
@@ -543,16 +425,6 @@ class GrootPolicy(PreTrainedPolicy):
""" """
self.eval() self.eval()
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
from .processor_groot import _GROOT_REF_HOLDER_KEY
holder = batch.get(_GROOT_REF_HOLDER_KEY)
if holder is not None:
holder.freeze()
# Preprocessing is handled by the processor pipeline, so we just filter the batch. # Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted. # During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor. # N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
@@ -599,21 +471,33 @@ class GrootPolicy(PreTrainedPolicy):
# Internal helpers # Internal helpers
# ------------------------- # -------------------------
def _handle_flash_attention_compatibility(self) -> None: def _handle_flash_attention_compatibility(self) -> None:
"""Log Flash Attention availability (diagnostic only). """Handle Flash Attention compatibility issues by setting environment variables.
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is This addresses the common 'undefined symbol' error that occurs when Flash Attention
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not is compiled against a different PyTorch version than what's currently installed.
change behaviour or mutate global state.
""" """
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try: try:
import flash_attn import flash_attn
logger.debug("Flash Attention %s is available.", flash_attn.__version__) print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError: except ImportError as e:
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.") print(f"[GROOT] Flash Attention not available: {e}")
except Exception as e: # noqa: BLE001 print("[GROOT] Will use fallback attention mechanism")
logger.warning( except Exception as e:
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is " if "undefined symbol" in str(e):
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.", print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
e, print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
) print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")
File diff suppressed because it is too large Load Diff
@@ -207,6 +207,11 @@ def test_lerobot_groot_forward_pass():
with torch.no_grad(): with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed) lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
assert isinstance(lerobot_loss, torch.Tensor)
assert torch.isfinite(lerobot_loss).all()
assert "loss" in lerobot_metrics
assert np.isfinite(lerobot_metrics["loss"])
print("\nForward pass successful.") print("\nForward pass successful.")
print(f" - Loss: {lerobot_loss.item():.6f}") print(f" - Loss: {lerobot_loss.item():.6f}")
print(f" - Metrics: {lerobot_metrics}") print(f" - Metrics: {lerobot_metrics}")
File diff suppressed because it is too large Load Diff