Compare commits

1 Commits

Author SHA1 Message Date
Steven Palma 895eaf0d7c fix(groot): N1.7 backbone loading and DiT parameter-count logging
- select_layer default tracks the N1.7-3B checkpoint value (16); real
  checkpoint loads still override it from config.json.
- get_backbone_cls recognizes Cosmos-Reason2 / Qwen3-VL backbones by name and
  warns (instead of silently assuming) when an unrecognized backbone is loaded
  only on the strength of backbone_model_type='qwen'.
- 'revision' pins the GR00T checkpoint repo only and is no longer forwarded
  into the unrelated backbone repo load; pin the backbone via
  transformers_loading_kwargs instead.
- DiT / SelfAttentionTransformer parameter counts go through logging.debug
  instead of print().
2026-06-12 23:55:33 +02:00
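The last bullet of the commit message (parameter counts routed through `logging.debug` instead of `print()`) can be illustrated with a minimal sketch. The helper names, the logger name, and the example shapes below are hypothetical and not taken from the commit; the real code presumably counts parameters on the model's tensors rather than on plain shape tuples.

```python
import logging
import math

# Hypothetical logger name, used only for this illustration.
logger = logging.getLogger("groot.example")


def count_parameters(shapes):
    """Sum the number of elements over a collection of parameter shapes."""
    return sum(math.prod(shape) for shape in shapes)


def log_parameter_count(name, shapes):
    """Report the count at DEBUG level so it stays silent unless requested."""
    total = count_parameters(shapes)
    logger.debug("%s parameter count: %d", name, total)
    return total


if __name__ == "__main__":
    # Opt in to debug output; without this call nothing is printed.
    logging.basicConfig(level=logging.DEBUG)
    log_parameter_count("DiT", [(1024, 1024), (1024,)])
```

Unlike `print()`, the message is dropped by default and only appears when the caller lowers the log level, which keeps library imports quiet.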
10 changed files with 225 additions and 1664 deletions
+1 -4
@@ -4,9 +4,6 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
LeRobot integrates GR00T N1.7 through the `groot` policy type.
> [!WARNING]
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
## Model Overview
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
@@ -136,7 +133,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash
hf download nvidia/GR00T-N1.7-LIBERO \
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
+23 -52
@@ -1,13 +1,6 @@
## Research Paper
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
> Current releases support GR00T N1.7 only.
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
## Repository
@@ -38,22 +31,12 @@ Hugging Face Models:
## Original-vs-LeRobot parity test
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
over every embodiment tag present in the checkpoint:
1. **Model parity** — given byte-identical pre-processed inputs and the same
flow-matching seed (recorded in each artifact), both implementations must produce
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
flow-matching prediction). Output shapes must match exactly; any action-horizon
or action-dim mismatch fails the test.
2. **Preprocessor parity** — given the identical raw observations (per-camera
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
state normalization, no mocks) must produce the **same collated model inputs**
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
`embodiment_id`) as the original package's processor.
produces the **same raw model output** (`get_action(...)["action_pred"]`, the
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
byte-identical pre-processed inputs and the same flow-matching seed. It is
parametrized over every embodiment tag present in the checkpoint.
### Why two environments
@@ -65,37 +48,25 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
So the test uses a **producer / consumer** split across two venvs:
1. **Producer**: `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
1. **Producer**: `tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
gr00t venv. For each embodiment it builds dummy inputs generically from the
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
the processor modality configs), runs the original model, and saves to one `.npz`
per tag: the raw observations (`raw::` keys), the exact collated inputs
(`in::` keys), the seed, and the raw `action_pred`.
2. **Consumer**: the pytest above, run in the _LeRobot_ venv. It discovers every
`.npz`; the model-parity case replays the byte-identical collated inputs through
the LeRobot model with the recorded seed and asserts the outputs match, and the
preprocessor-parity case replays the raw observations through LeRobot's full
preprocessor pipeline and asserts the collated tensors match.
> Artifacts generated by older versions of the dump script contain no `raw::`
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
> Re-run the producer to refresh them.
the processor modality configs), runs the original model, and saves the exact
collated inputs + raw `action_pred` to one `.npz` per tag.
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
`.npz`, replays the byte-identical inputs through the LeRobot model with the same
seed, and asserts the outputs match.
### Fairness controls
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
- **Same pre-processed inputs** — the original processor's `input_ids`,
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
model comparison isolates the model. LeRobot's own tokenization / image packing is
covered separately by the preprocessor-parity case, which compares its output
against those same collated tensors from identical raw observations.
fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
kernel/rounding noise, not an implementation difference.)
- **Same flow-matching seed** — fixed right before sampling on both sides; the
producer records it in each artifact (`--seed`, default 42) and the consumer
replays the recorded value.
- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
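The seed control above can be sketched in isolation. This is an illustration only: it uses Python's standard `random` module and a hypothetical `sample_noise` helper, whereas the actual test seeds the framework's generator before flow-matching sampling.

```python
import random


def sample_noise(seed, n):
    """Draw n Gaussian values from a generator seeded right before sampling."""
    rng = random.Random(seed)  # fix the seed immediately before drawing
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


# Two independent "implementations" that share the seed draw identical noise.
producer_noise = sample_noise(42, 4)  # stands in for the original-package side
consumer_noise = sample_noise(42, 4)  # stands in for the LeRobot side
assert producer_noise == consumer_noise

# A different seed produces a different starting point for sampling.
assert sample_noise(43, 4) != producer_noise
```

Because flow matching starts from sampled noise, sharing the seed removes the sampler as a source of difference, so any remaining gap comes from the model computation itself.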
### How to run
@@ -119,15 +90,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
```
The `.npz` artifacts are local-only (gitignored, ~610 MB each) and are regenerated by
the producer; they are never committed. The tests **skip** (do not fail) on CI or
The `.npz` artifacts are local-only (gitignored, ~69 MB each) and are regenerated by
the producer; they are never committed. The test **skips** (does not fail) on CI or
when the checkpoint / artifacts are absent.
#### Env knobs (all optional)
| Var | Default | Purpose |
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
| Var | Default | Purpose |
|---|---|---|
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
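As a rough sketch of how the tolerance knobs in the table could be consumed, the snippet below reads both variables with the documented `1e-3` default and applies an allclose-style check. The helper names are hypothetical, and whether the test combines the two tolerances exactly this way is an assumption; the formula shown is the one used by `numpy.allclose` and `torch.allclose`.

```python
import os


def read_tolerances(env=None):
    """Read atol/rtol from the environment, defaulting to 1e-3 as in the table."""
    env = os.environ if env is None else env
    atol = float(env.get("GROOT_PARITY_ATOL", "1e-3"))
    rtol = float(env.get("GROOT_PARITY_RTOL", "1e-3"))
    return atol, rtol


def all_close(actual, expected, atol, rtol):
    """Check |a - e| <= atol + rtol * |e| for every pair of values."""
    if len(actual) != len(expected):
        return False  # a length mismatch is a failure, not a partial pass
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


if __name__ == "__main__":
    atol, rtol = read_tolerances()
    print(all_close([1.0005, 2.0], [1.0, 2.0], atol, rtol))
```

Passing the environment as an argument keeps the helper easy to exercise without mutating `os.environ`.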
+37 -121
@@ -15,7 +15,6 @@
# limitations under the License.
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -24,29 +23,15 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
logger = logging.getLogger(__name__)
GROOT_N1_7 = "n1.7"
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
# still rejects it). It is retained only so that infer_groot_model_version can recognise
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
GROOT_N1_5 = "n1.5"
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
GROOT_N1_5_REMOVAL_GUIDANCE = (
"GR00T N1.5 support was removed from LeRobot. "
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.7": GROOT_N1_7,
@@ -56,12 +41,7 @@ _GROOT_MODEL_VERSION_ALIASES = {
"1.7": GROOT_N1_7,
}
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
"none": None,
"": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
@@ -72,10 +52,9 @@ def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = GROOT_N1_7
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
raise ValueError(
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
)
return normalized
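The behaviour of the alias lookup in the hunk above can be reproduced with a small self-contained sketch. The alias table here contains only the two spellings visible in the diff (the rest of the table is elided by the hunk boundaries), and `normalize_model_version` is a stand-in name rather than the function from the repository.

```python
GROOT_N1_7 = "n1.7"

# Only the spellings visible in the diff; the full table is not shown there.
_ALIASES = {"n1.7": GROOT_N1_7, "1.7": GROOT_N1_7}


def normalize_model_version(model_version: str) -> str:
    """Map a case-insensitive alias to the canonical version string or raise."""
    normalized = _ALIASES.get(model_version.lower())
    if normalized is None:
        raise ValueError(
            f"Unsupported GR00T model_version '{model_version}'. "
            f"Supported versions: {GROOT_N1_7}."
        )
    return normalized


if __name__ == "__main__":
    print(normalize_model_version("N1.7"))
    try:
        normalize_model_version("n1.5")
    except ValueError as err:
        print(err)
```

Any spelling missing from the table, including N1.5 spellings, takes the same `ValueError` path.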
@@ -307,8 +286,6 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version")
if isinstance(model_version, str):
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
return GROOT_N1_5
try:
return normalize_groot_model_version(model_version)
except ValueError:
@@ -321,17 +298,8 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
# instead of being silently treated as N1.7 (see infer_groot_model_version).
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
backbone_cfg = config.get("backbone_cfg")
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
return GROOT_N1_5
return None
@@ -342,32 +310,27 @@ class GrootConfig(PreTrainedConfig):
# Basic policy settings
n_obs_steps: int = 1
chunk_size: int = 40
n_action_steps: int = 40
chunk_size: int = 50
n_action_steps: int = 50
# Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 132
max_state_dim: int = 64
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 132
max_action_dim: int = 32
# GR00T normalizes state/action internally in its processor steps (min/max with
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
# handles image normalization. The policy therefore does NOT use LeRobot's
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
# Normalization (start with identity, adjust as needed)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Deprecated and unused: image sizing is handled by the backbone's image processor.
# Kept only so config.json files saved with earlier versions still parse.
image_size: tuple[int, int] = (256, 256)
# Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Groot-specific model parameters (from groot_finetune_script.py)
@@ -381,14 +344,7 @@ class GrootConfig(PreTrainedConfig):
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
tokenizer_assets_repo: str | None = None
action_decode_transform: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -428,13 +384,17 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05
use_bf16: bool = True
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
# Dataset parameters
# Video backend to use for training ('decord' or 'torchvision_av')
video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t"
save_steps: int = 1000
@@ -445,15 +405,6 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
# serialized into every groot checkpoint config.json, so a value here means a legacy
# N1.5 checkpoint or config is being loaded.
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
)
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
@@ -465,48 +416,26 @@ class GrootConfig(PreTrainedConfig):
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
# This matches the embodiment-specific handling already done for the
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
# 'none' (normalized to None above) keeps the transform disabled.
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
self.action_decode_transform = (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
)
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
# warning. The dataclass defaults are already the N1.7 values, so a plain
# GrootConfig() never triggers this.
legacy_default_remaps = (
("max_state_dim", 64, 132),
("max_action_dim", 32, 132),
("chunk_size", 50, 40),
("n_action_steps", 50, 40),
("image_size", (224, 224), (256, 256)),
)
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
current_value = getattr(self, field_name)
if isinstance(legacy_value, tuple):
current_value = tuple(current_value)
if current_value == legacy_value:
logger.warning(
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
"if this is not what you want.",
field_name,
legacy_value,
n1_7_value,
)
setattr(self, field_name, n1_7_value)
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
message = (
raise ValueError(
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
super().__post_init__()
@@ -582,22 +511,9 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions.
The model action horizon is read from the checkpoint's processor_config.json
when available; the result is cached (keyed on the inputs that determine it) so
repeated access during dataset/training setup does not re-read from disk.
"""
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
cached = getattr(self, "_action_delta_indices_cache", None)
if cached is not None and cached[0] == cache_key:
return cached[1]
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
indices = list(range(min(self.chunk_size, model_action_horizon)))
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
return indices
"""Return indices for delta actions."""
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
return list(range(min(self.chunk_size, model_action_horizon)))
@property
def reward_delta_indices(self) -> None:
+1 -7
@@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
"select_layer": 16,
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
@@ -822,8 +822,6 @@ def get_backbone_cls(config: GR00TN17Config):
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
return Qwen3Backbone
if config.backbone_model_type == "qwen":
# Local backbone checkpoints (e.g. hub-cache snapshot paths) contain neither hub
# marker, so trust the explicit backbone type but surface what is being assumed.
logger.warning(
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
"backbone because backbone_model_type='qwen'.",
@@ -914,10 +912,6 @@ class GR00TN17(PreTrainedModel):
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
# Only repo-agnostic hub kwargs are forwarded to the backbone loading kwargs:
# ``revision`` pins the GR00T checkpoint repo (see snapshot_download below) and would
# be invalid for the unrelated backbone repo (``config.model_name``). Pin the backbone
# itself by passing ``revision`` inside ``transformers_loading_kwargs``.
for key in ("cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
+49 -46
@@ -18,12 +18,15 @@
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T N1.7 components where
possible without porting their code. Dataset loading and training
orchestration are handled by LeRobot's standard training stack.
possible without porting their code.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
"""
import builtins
import logging
import os
from collections import deque
from pathlib import Path
@@ -39,8 +42,6 @@ from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_groot import (
GROOT_N1_5,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
@@ -49,8 +50,6 @@ from .configuration_groot import (
normalize_groot_model_version,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy")
@@ -93,11 +92,8 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True},
)
# GR00TN17 defines no compute_dtype attribute, so only record the
# bf16 preference when it is enabled instead of reading a default back.
if self.config.use_bf16:
model.compute_dtype = "bfloat16"
model.config.compute_dtype = "bfloat16"
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
return model
@@ -152,10 +148,9 @@ class GrootPolicy(PreTrainedPolicy):
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
logger.info(
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
requested_version,
pretrained_name_or_path,
print(
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
)
model_id = str(pretrained_name_or_path)
@@ -186,7 +181,7 @@ class GrootPolicy(PreTrainedPolicy):
if is_finetuned_checkpoint:
# This is a fine-tuned LeRobot checkpoint - use parent class loading
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
return super().from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
config=config,
@@ -202,7 +197,7 @@ class GrootPolicy(PreTrainedPolicy):
)
# This is a base GR00T model - load it fresh
logger.info("Detected base GR00T model, loading from HuggingFace...")
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
@@ -234,13 +229,10 @@ class GrootPolicy(PreTrainedPolicy):
config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version:
message = (
raise ValueError(
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
raise ValueError(message)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -305,7 +297,9 @@ class GrootPolicy(PreTrainedPolicy):
allowed_base.add("action_mask")
return {
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
k: v
for k, v in batch.items()
if k in allowed_base and not (k.startswith("next.") or k == "info")
}
def _prepare_n1_7_rtc_inputs(
@@ -326,7 +320,9 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.ndim == 2:
prev_actions = prev_actions.unsqueeze(0)
elif prev_actions.ndim != 3:
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
raise ValueError(
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
)
state = inputs.get("state")
if state is None:
@@ -335,7 +331,9 @@ class GrootPolicy(PreTrainedPolicy):
if prev_actions.shape[0] == 1 and batch_size > 1:
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
elif prev_actions.shape[0] != batch_size:
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
raise ValueError(
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
)
# The generic LeRobot RTC engine pads short leftovers with exact zero
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
@@ -348,9 +346,7 @@ class GrootPolicy(PreTrainedPolicy):
else:
return inputs, None
model_action_horizon = int(
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
)
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
if prev_actions.shape[1] > model_action_horizon:
prev_actions = prev_actions[:, -model_action_horizon:, :]
@@ -413,11 +409,6 @@ class GrootPolicy(PreTrainedPolicy):
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss")
if loss is None:
raise RuntimeError(
"GR00T model.forward did not return a 'loss'. Training batches must include "
"'action' and 'action_mask'; check the preprocessor output."
)
loss_dict = {"loss": loss.item()}
@@ -480,21 +471,33 @@ class GrootPolicy(PreTrainedPolicy):
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Log Flash Attention availability (diagnostic only).
"""Handle Flash Attention compatibility issues by setting environment variables.
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
change behaviour or mutate global state.
This addresses the common 'undefined symbol' error that occurs when Flash Attention
is compiled against a different PyTorch version than what's currently installed.
"""
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try:
import flash_attn
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
except ImportError:
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
except Exception as e: # noqa: BLE001
logger.warning(
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
e,
)
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError as e:
print(f"[GROOT] Flash Attention not available: {e}")
print("[GROOT] Will use fallback attention mechanism")
except Exception as e:
if "undefined symbol" in str(e):
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")
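For comparison, a diagnostic-only probe that reports through `logging` and leaves environment variables untouched might look like the sketch below. The function and logger names are hypothetical; only `importlib.import_module` and the standard `logging` calls are real APIs, and the sketch does not reproduce either version of the method in the diff.

```python
import importlib
import logging

# Hypothetical logger name, used only for this illustration.
logger = logging.getLogger("groot.example")


def probe_optional_module(module_name="flash_attn"):
    """Return the module's version string if it imports cleanly, else None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        logger.debug("%s is not installed.", module_name)
        return None
    except Exception as exc:  # any other failure raised while importing
        logger.warning("%s failed to import (%s).", module_name, exc)
        return None
    # Some modules do not define __version__, so fall back to a placeholder.
    return getattr(module, "__version__", "unknown")


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    print(probe_optional_module())
```

The probe only observes and reports, so calling it repeatedly has no effect on later imports or on global state.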
+38 -297
@@ -15,16 +15,14 @@
# limitations under the License.
import json
import logging
from copy import copy
from dataclasses import dataclass, field, fields, is_dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
@@ -61,7 +59,6 @@ from lerobot.utils.constants import (
from .configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7_BACKBONE_MODEL,
GrootConfig,
is_raw_groot_n1_7_checkpoint,
@@ -83,6 +80,12 @@ N1_7_EMBODIMENT_MAPPING = {
"new_embodiment": 10,
}
_N1_7_RAW_STATE_CACHE: dict[str, dict[str, np.ndarray]] = {}
def _n1_7_state_cache_key(value: str | None) -> str:
return value or "groot_n1_7_default"
@dataclass
class _GrootN17CheckpointProcessorAssets:
@@ -468,98 +471,25 @@ def _legacy_groot_processor_overrides(
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
"""Check whether a serialized processor pipeline contains a registry step.
Resolves the processor config from a local directory or, for Hub repo ids,
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
False when the config cannot be resolved; loading then proceeds with the
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
without them if they do not match the serialized pipeline.
"""
def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
path = Path(pretrained_path).expanduser()
if path.is_dir():
config = _read_json(path / config_filename)
elif path.exists():
if not path.is_dir():
return False
else:
try:
config_path = hf_hub_download(
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
)
except Exception:
return False
config = _read_json(Path(config_path))
config = _read_json(path / config_filename)
steps = config.get("steps", [])
if not isinstance(steps, list):
return False
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
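The step lookup above reduces to scanning the `steps` list of a serialized pipeline config for a matching `registry_name`. A minimal standalone sketch, assuming the config has already been parsed into a dict; the function name `config_has_step` and the sample config are illustrative and not part of LeRobot:

```python
from typing import Any


def config_has_step(config: dict[str, Any], step_name: str) -> bool:
    """Return True when a parsed pipeline config lists a step registered as step_name."""
    steps = config.get("steps", [])
    # A malformed config (e.g. "steps" is None or a dict) is treated as "no match".
    if not isinstance(steps, list):
        return False
    return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)


# Hypothetical config used only for illustration.
sample_config = {
    "steps": [
        {"registry_name": "groot_n1_7_action_decode_v1"},
        "not-a-dict-entry",  # non-dict entries are skipped rather than raising
    ]
}
```

Returning `False` for anything unexpected keeps the check conservative: the caller then falls back to its default override path instead of failing on a config it cannot inspect.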
def _apply_groot_step_overrides(
pipeline: PolicyProcessorPipeline,
overrides: dict[str, Any] | None,
) -> None:
"""Apply ``from_pretrained``-style step overrides to a freshly built pipeline.
Raw N1.7 checkpoints build their processors from scratch instead of
deserializing them, so caller overrides must be applied to the constructed
steps. Override keys match a step's registry name or, as a convenience, its
class name (``PolicyProcessorPipeline.from_pretrained`` matches registered
steps by registry name only; prefer registry names so overrides keep
working after the checkpoint is converted and reloaded from a serialized
pipeline). Keys or fields that match nothing raise instead of being dropped
silently.
"""
if not overrides:
return
def _step_keys(step: ProcessorStep) -> set[str]:
keys = {type(step).__name__}
registry_name = getattr(type(step), "_registry_name", None)
if registry_name:
keys.add(registry_name)
return keys
for override_key, step_overrides in overrides.items():
matched_steps = [step for step in pipeline.steps if override_key in _step_keys(step)]
if not matched_steps:
available = [
getattr(type(step), "_registry_name", None) or type(step).__name__ for step in pipeline.steps
]
raise KeyError(
f"Override key '{override_key}' does not match any step of the GR00T processor pipeline "
f"built for this raw N1.7 checkpoint. Available step keys: {available}."
)
for step in matched_steps:
if not is_dataclass(step):
raise TypeError(
f"Cannot apply overrides to step '{override_key}': it is not a dataclass step."
)
init_field_names = {f.name for f in fields(step) if f.init}
for field_name, value in dict(step_overrides).items():
if field_name not in init_field_names:
raise TypeError(
f"Override field '{field_name}' is not a config field of step '{override_key}'. "
f"Available fields: {sorted(init_field_names)}."
)
setattr(step, field_name, value)
# Re-derive attributes computed from the overridden config (e.g.
# DeviceProcessorStep resolves its torch.device in __post_init__).
post_init = getattr(step, "__post_init__", None)
if callable(post_init):
post_init()
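The override logic in `_apply_groot_step_overrides` can be tried in isolation. The sketch below is a simplified, self-contained approximation under stated assumptions: `ToyDeviceStep` and `apply_overrides` are hypothetical names, and the toy step only mimics a dataclass step that derives an attribute in `__post_init__`.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Optional


@dataclass
class ToyDeviceStep:
    """Hypothetical stand-in for a processor step; not a LeRobot class."""

    device: str = "cpu"
    _registry_name = "toy_device_step"  # plain class attribute, not a dataclass field

    def __post_init__(self) -> None:
        # Derived attribute that has to be recomputed after an override.
        self.is_gpu = self.device.startswith("cuda")


def apply_overrides(steps: list[Any], overrides: Optional[dict[str, dict[str, Any]]]) -> None:
    """Apply per-step overrides, raising on keys or fields that match nothing."""
    if not overrides:
        return
    for key, step_overrides in overrides.items():
        matched = [
            step
            for step in steps
            if key in {type(step).__name__, getattr(type(step), "_registry_name", None)}
        ]
        if not matched:
            raise KeyError(f"Override key '{key}' does not match any step.")
        for step in matched:
            if not is_dataclass(step):
                raise TypeError(f"Step '{key}' is not a dataclass step.")
            allowed = {f.name for f in fields(step) if f.init}
            for name, value in step_overrides.items():
                if name not in allowed:
                    raise TypeError(f"'{name}' is not a config field of step '{key}'.")
                setattr(step, name, value)
            # Re-run __post_init__ so derived attributes reflect the new values.
            post_init = getattr(step, "__post_init__", None)
            if callable(post_init):
                post_init()
```

Raising on unmatched keys trades convenience for safety: a typo in an override name surfaces immediately instead of leaving the pipeline on its default device.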
def make_groot_pre_post_processors_from_pretrained(
config: GrootConfig,
pretrained_path: str,
@@ -578,20 +508,12 @@ def make_groot_pre_post_processors_from_pretrained(
if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config)
processor_cfg.base_model_path = str(pretrained_path)
preprocessor, postprocessor = make_groot_pre_post_processors(
return make_groot_pre_post_processors(
config=processor_cfg,
dataset_stats=dataset_stats,
)
# Raw checkpoints have no serialized pipelines to load overrides into,
# so apply the caller overrides (e.g. device and rename_map from
# lerobot-eval or the policy server) to the freshly built steps.
_apply_groot_step_overrides(preprocessor, preprocessor_overrides)
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
return preprocessor, postprocessor
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
if _pretrained_processor_config_has_step(
if _local_processor_config_has_step(
pretrained_path,
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
@@ -599,55 +521,15 @@ def make_groot_pre_post_processors_from_pretrained(
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
# action decoder. Adding the legacy action-unpack override would target
# a step that is not present and break loading.
applied_legacy_overrides = False
preprocessor_overrides = caller_preprocessor_overrides
postprocessor_overrides = caller_postprocessor_overrides
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
else:
applied_legacy_overrides = True
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
try:
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
except KeyError:
if not applied_legacy_overrides:
raise
# The legacy overrides target steps that are absent from the serialized
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
# config could not be inspected before loading); retry with the caller
# overrides only.
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=caller_preprocessor_overrides,
postprocessor_overrides=caller_postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
def _load_groot_processor_pipelines(
pretrained_path: str,
*,
preprocessor_overrides: dict[str, Any],
postprocessor_overrides: dict[str, Any],
preprocessor_config_filename: str,
postprocessor_config_filename: str,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
preprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=preprocessor_config_filename,
@@ -662,6 +544,7 @@ def _load_groot_processor_pipelines(
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
@@ -681,28 +564,6 @@ def _reconnect_groot_relative_absolute_steps(
step.relative_step = relative_step
def _reconnect_groot_n1_7_pack_decode_steps(
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
) -> None:
"""Re-link a deserialized N1.7 action decode step to its pack step.
The pack step holds the per-instance raw-state cache that relative-action
decoding reads its reference state from; the link itself is not serialized.
"""
pack_step = next(
(step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep)),
None,
)
if pack_step is None:
return
for step in postprocessor.steps:
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
step.pack_step = pack_step
def make_groot_pre_post_processors(
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[
@@ -746,12 +607,9 @@ def make_groot_pre_post_processors(
else action_horizon
)
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
checkpoint_has_stats = _has_modality_stats(checkpoint_stats)
padded_stats = checkpoint_stats if checkpoint_has_stats else (dataset_stats or {})
padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {})
embodiment_mapping = (
checkpoint_assets.embodiment_mapping
if checkpoint_assets is not None
else dict(N1_7_EMBODIMENT_MAPPING)
checkpoint_assets.embodiment_mapping if checkpoint_assets is not None else dict(N1_7_EMBODIMENT_MAPPING)
)
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
@@ -760,6 +618,7 @@ def make_groot_pre_post_processors(
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
state_cache_key = f"groot_n1_7:{config.embodiment_tag}"
pack_step = GrootN17PackInputsStep(
state_horizon=1,
action_horizon=action_horizon,
@@ -777,6 +636,7 @@ def make_groot_pre_post_processors(
video_modality_keys=video_modality_keys,
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
state_cache_key=state_cache_key,
)
input_steps: list[ProcessorStep] = [
@@ -787,31 +647,14 @@ def make_groot_pre_post_processors(
model_name=config.n1_7_backbone_model,
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge
if checkpoint_assets is not None
else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge if checkpoint_assets is not None else None,
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
use_albumentations=checkpoint_assets.use_albumentations
if checkpoint_assets is not None
else False,
use_albumentations=checkpoint_assets.use_albumentations if checkpoint_assets is not None else False,
),
DeviceProcessorStep(device=config.device),
]
if checkpoint_assets is not None and not checkpoint_has_stats and not _has_modality_stats(padded_stats):
raise ValueError(
f"GR00T N1.7 checkpoint '{config.base_model_path}' has no statistics for embodiment tag "
f"'{config.embodiment_tag}', and no dataset stats were provided to fall back to, so "
"actions cannot be normalized or decoded. Pass dataset_stats, or set "
"config.embodiment_tag to an embodiment present in the checkpoint's statistics.json."
)
if checkpoint_assets is None or not checkpoint_has_stats:
# When the checkpoint sidecars have no stats for the configured
# embodiment tag (e.g. finetuning a raw base checkpoint with the
# default 'new_embodiment' tag), the pack step above normalized with
# the dataset stats; the decode step must invert with the same stats
# instead of using a checkpoint decoder whose empty stats would
# silently return normalized [-1, 1] actions.
if checkpoint_assets is None:
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
env_action_dim=env_action_dim,
stats=padded_stats,
@@ -826,6 +669,7 @@ def make_groot_pre_post_processors(
use_percentiles=checkpoint_assets.use_percentiles,
use_relative_action=checkpoint_assets.use_relative_action,
pack_step=pack_step,
state_cache_key=state_cache_key,
action_decode_transform=config.action_decode_transform,
)
@@ -926,7 +770,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
try:
from transformers import (
AutoTokenizer,
Qwen2VLImageProcessor,
Qwen2VLImageProcessorFast,
Qwen3VLProcessor,
Qwen3VLVideoProcessor,
)
@@ -937,7 +781,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
) from exc
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
image_processor = Qwen2VLImageProcessor.from_pretrained(model_name, trust_remote_code=True)
image_processor = Qwen2VLImageProcessorFast.from_pretrained(model_name, trust_remote_code=True)
video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True)
proc = Qwen3VLProcessor(
image_processor=image_processor,
@@ -1058,11 +902,8 @@ class GrootN17PackInputsStep(ProcessorStep):
video_modality_keys: list[str] | None = None
raw_stats: dict[str, Any] | None = None
modality_config: dict[str, Any] | None = None
# Unused: kept so serialized configs that include it still load. The raw
# state cache is per instance (_last_raw_state), never process-global.
state_cache_key: str = ""
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
_warned_image_keys: bool = field(default=False, init=False, repr=False)
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
available = {key for key in obs if key.startswith(OBS_IMAGES)}
@@ -1072,56 +913,19 @@ class GrootN17PackInputsStep(ProcessorStep):
return sorted(available)
ordered: list[str] = []
unmatched: list[str] = []
for modality_key in self.video_modality_keys:
candidates = [f"{OBS_IMAGES}.{modality_key}"]
# Alias for datasets converted with generic camera names (e.g. the
# LIBERO conversions expose the wrist camera as
# `observation.images.image2`), so raw N1.7 LIBERO checkpoints
# match those datasets out of the box.
if modality_key == "wrist_image":
candidates.append(f"{OBS_IMAGES}.image2")
elif modality_key == "image":
candidates.append(f"{OBS_IMAGES}.image")
match = next((candidate for candidate in candidates if candidate in available), None)
if match is None:
unmatched.append(modality_key)
else:
if match is not None:
ordered.append(match)
if not ordered:
if not self._warned_image_keys:
self._warned_image_keys = True
logging.warning(
"None of the GR00T N1.7 checkpoint video modality keys %s match a camera among %s; "
"falling back to feeding all cameras in alphabetical order, which is unlikely to be "
"the layout the checkpoint was trained with. Rename the dataset cameras (e.g. via "
"--rename_map) to match the checkpoint keys.",
self.video_modality_keys,
sorted(available),
)
return sorted(available)
unused = sorted(available - set(ordered))
if (unmatched or unused) and not self._warned_image_keys:
self._warned_image_keys = True
if unmatched:
logging.warning(
"GR00T N1.7 checkpoint video modality keys %s have no matching camera among %s; "
"the model will receive %d view(s) instead of the %d it was trained with. Rename "
"the dataset cameras (e.g. via --rename_map) to match the checkpoint keys %s.",
unmatched,
sorted(available),
len(ordered),
len(self.video_modality_keys),
self.video_modality_keys,
)
if unused:
logging.warning(
"Dropping camera(s) %s: the GR00T N1.7 checkpoint only consumes the video modality "
"keys %s, which matched %s.",
unused,
self.video_modality_keys,
ordered,
)
return ordered
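The camera matching in `_ordered_image_keys` can be summarized as a pure function that returns the ordered keys together with the unmatched and unused ones. This is a simplified sketch, not the method itself: `order_image_keys` is a hypothetical name, the `observation.images` prefix is taken from the alias mentioned in the comment above, and the warnings are replaced by return values.

```python
OBS_IMAGES = "observation.images"  # assumed value of the OBS_IMAGES constant


def order_image_keys(
    available: set[str], modality_keys: list[str]
) -> tuple[list[str], list[str], list[str]]:
    """Return (ordered camera keys, unmatched modality keys, unused camera keys)."""
    ordered: list[str] = []
    unmatched: list[str] = []
    for modality_key in modality_keys:
        candidates = [f"{OBS_IMAGES}.{modality_key}"]
        # Alias for datasets that expose the wrist camera under a generic name.
        if modality_key == "wrist_image":
            candidates.append(f"{OBS_IMAGES}.image2")
        match = next((c for c in candidates if c in available), None)
        if match is None:
            unmatched.append(modality_key)
        else:
            ordered.append(match)
    if not ordered:
        # Nothing matched: fall back to every camera in alphabetical order.
        return sorted(available), unmatched, []
    unused = sorted(available - set(ordered))
    return ordered, unmatched, unused
```

Returning the unmatched and unused keys makes it straightforward for a caller to log one warning per pipeline rather than one per call.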
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -1184,6 +988,7 @@ class GrootN17PackInputsStep(ProcessorStep):
start_idx += dim
if grouped:
self._last_raw_state = grouped
_N1_7_RAW_STATE_CACHE[_n1_7_state_cache_key(self.state_cache_key)] = grouped
img_keys = self._ordered_image_keys(obs)
if img_keys:
@@ -1296,6 +1101,7 @@ class GrootN17PackInputsStep(ProcessorStep):
"video_modality_keys": self.video_modality_keys,
"raw_stats": self.raw_stats,
"modality_config": self.modality_config,
"state_cache_key": self.state_cache_key,
}
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
@@ -1417,7 +1223,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"use_albumentations": self.use_albumentations,
}
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
for stat_name in ("mean", "q01", "min", "max", "std"):
value = entry.get(stat_name)
@@ -1546,18 +1351,6 @@ class GrootN17ActionDecodeStep(ProcessorStep):
group with the checkpoint stats, converts relative groups to absolute values
using the raw state cached during packing, concatenates groups in checkpoint
order, and finally slices to the environment action dimension.
Relative-action decoding reads the reference state from the connected
``pack_step`` (re-linked after ``from_pretrained`` by
``_reconnect_groot_n1_7_pack_decode_steps``), i.e. the state seen by the
most recent preprocess call. Engines that decode the whole chunk right
after prediction (RTC, async policy server) therefore use the
prediction-time state, matching Isaac-GR00T. The sync per-step queue path
instead decodes each popped (B, D) action against the latest observation:
the reference can be newer than the observation the chunk was predicted
from, and per-timestep relative stats are applied as if the popped action
were chunk step 0. Fixing that would require carrying the reference state
and chunk index alongside each queued action through the postprocessor.
"""
env_action_dim: int = 0
@@ -1565,7 +1358,6 @@ class GrootN17ActionDecodeStep(ProcessorStep):
modality_config: dict[str, Any] | None = None
use_percentiles: bool = False
use_relative_action: bool = False
# Unused: kept so serialized configs that include it still load.
state_cache_key: str = ""
action_decode_transform: str | None = None
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
@@ -1586,12 +1378,6 @@ class GrootN17ActionDecodeStep(ProcessorStep):
return transition
action_np = action.detach().cpu().float().numpy()
# The sync action queue postprocesses popped actions as (B, D); decode
# them as single-step (B, 1, D) chunks and squeeze the horizon back at
# the end so both ranks share the chunk decode logic below.
squeeze_horizon = action_np.ndim == 2
if squeeze_horizon:
action_np = action_np[:, None, :]
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
if valid_horizon is not None:
action_np = action_np[:, :valid_horizon]
@@ -1619,24 +1405,17 @@ class GrootN17ActionDecodeStep(ProcessorStep):
use_relative_action=self.use_relative_action,
use_percentiles=self.use_percentiles,
)
# Per-timestep stats carry one row per chunk step; align them with
# the decoded horizon (chunks always start at step 0, and a popped
# (B, D) action is decoded as step 0).
if min_v.ndim == 2 and normalized.shape[1] <= min_v.shape[0]:
min_v = min_v[: normalized.shape[1]]
max_v = max_v[: normalized.shape[1]]
decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v)
start_idx += dim
if self.use_relative_action:
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
if raw_state is None:
raw_state = _N1_7_RAW_STATE_CACHE.get(_n1_7_state_cache_key(self.state_cache_key))
if raw_state is None:
raise RuntimeError(
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
"GrootN17PackInputsStep to convert relative N1.7 actions back to absolute actions. "
"Build both pipelines through make_groot_pre_post_processors (or load them together "
"via make_groot_pre_post_processors_from_pretrained) and run the preprocessor on an "
"observation before decoding actions."
"GrootN17ActionDecodeStep requires cached raw state from GrootN17PackInputsStep "
"to convert relative N1.7 actions back to absolute actions."
)
for idx, key in enumerate(action_keys):
if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs):
@@ -1672,8 +1451,6 @@ class GrootN17ActionDecodeStep(ProcessorStep):
action_keys=action_keys,
decoded_groups=decoded_groups,
)
if squeeze_horizon:
decoded = decoded[:, 0]
new_transition = transition.copy()
new_transition[TransitionKey.ACTION] = torch.as_tensor(
decoded, dtype=action.dtype, device=action.device
@@ -1690,15 +1467,13 @@ class GrootN17ActionDecodeStep(ProcessorStep):
"modality_config": self.modality_config,
"use_percentiles": self.use_percentiles,
"use_relative_action": self.use_relative_action,
"state_cache_key": self.state_cache_key,
"action_decode_transform": self.action_decode_transform,
}
@dataclass
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
# action chunks to the last timestep, so old serialized v1 pipelines must not
# silently load into it (v1 is stubbed below with the removal guidance).
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0
# Apply inverse of min-max normalization if it was used in preprocessor
@@ -1810,37 +1585,3 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
if reconstructed:
self.stats = reconstructed
def _register_removed_n1_5_step_stub(registry_name: str) -> None:
"""Register a stub for a processor step that only GR00T N1.5 pipelines serialize.
Saved N1.5 checkpoints reference these registry names in their processor JSON
files. Deserializing them must fail with the canonical N1.5 removal guidance
instead of an opaque registry KeyError (or, for
``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose
action-chunk semantics changed).
"""
@ProcessorStepRegistry.register(name=registry_name)
class _RemovedGrootN15ProcessorStep(ProcessorStep):
def __init__(self, **_kwargs: Any) -> None:
raise ValueError(
f"Processor step '{registry_name}' belongs to a GR00T N1.5 processor pipeline. "
f"{GROOT_N1_5_REMOVAL_GUIDANCE}"
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
raise NotImplementedError
def transform_features(self, features):
raise NotImplementedError
for _removed_n1_5_step_name in (
"groot_pack_inputs_v3",
"groot_eagle_encode_v3",
"groot_eagle_collate_v3",
"groot_action_unpack_unnormalize_v1",
):
_register_removed_n1_5_step_stub(_removed_n1_5_step_name)
@@ -207,11 +207,6 @@ def test_lerobot_groot_forward_pass():
with torch.no_grad():
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
assert isinstance(lerobot_loss, torch.Tensor)
assert torch.isfinite(lerobot_loss).all()
assert "loss" in lerobot_metrics
assert np.isfinite(lerobot_metrics["loss"])
print("\nForward pass successful.")
print(f" - Loss: {lerobot_loss.item():.6f}")
print(f" - Metrics: {lerobot_metrics}")
File diff suppressed because it is too large
+26 -175
@@ -14,36 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
once in the original ``gr00t`` env by the companion script
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
normalized flow-matching prediction before any action decoding) as NVIDIA's original
``gr00t`` package, given byte-identical pre-processed inputs and the same
flow-matching seed. The comparison is parametrized over every embodiment tag present
in the checkpoint.
1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
action head + Qwen3-VL backbone must produce the SAME raw model output
(``action_pred``, the normalized flow-matching prediction before any action
decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
pre-processed inputs and the flow-matching seed recorded in the artifact.
2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
template / tokenizer / image packing + state normalization, no mocks) must produce
the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
as the original package's processor, given the identical raw observations
(images, state, language) recorded in the artifact. Artifacts written by older
versions of the dump script carry no raw observations; this case then SKIPS with
a regeneration hint.
To keep the comparison fair, the original outputs + the exact collated inputs are
produced once per embodiment in the original ``gr00t`` env via the companion script
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
to per-tag ``.npz`` files.
This test discovers those artifacts, replays the identical inputs through the LeRobot
model, and compares.
These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
present, or when no artifact has been generated. By default they look for artifacts in
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
present, or when no artifact has been generated. By default it looks for artifacts in
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
for the full run procedure.
"""
import os
import warnings
from pathlib import Path
from typing import Any
import numpy as np
import pytest
@@ -55,9 +50,7 @@ pytestmark = pytest.mark.skipif(
)
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
SEED = 42
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
@@ -67,11 +60,6 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
_ARTIFACT_PREFIX = "original_n1_7_"
_ARTIFACT_SUFFIX = ".npz"
# Collated keys compared by the preprocessor parity case: integer/id tensors must
# match exactly; float tensors within ATOL/RTOL.
_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
def _artifact_dir() -> Path:
"""Directory holding the per-embodiment .npz artifacts.
@@ -121,20 +109,9 @@ def _resolve_checkpoint() -> str:
return str(ckpt)
def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
"""Return (original action_pred, collated model inputs, flow-matching seed)."""
def _load_artifact(path: Path):
data = np.load(path, allow_pickle=True)
original_action = torch.from_numpy(data["action_pred"]).float()
if "seed" in data.files:
seed = int(data["seed"])
else:
warnings.warn(
f"Artifact '{path.name}' does not record the producer seed (it predates the current "
f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
"regenerate the artifact with the current dump script.",
stacklevel=2,
)
seed = SEED
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
inputs = {}
for key in data.files:
@@ -147,45 +124,7 @@ def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], i
if "int" in declared or "long" in declared:
t = t.long()
inputs[name] = t
return original_action, inputs, seed
def _load_raw_observation(path: Path) -> dict[str, Any] | None:
"""Return the raw observation recorded in the artifact, or None for old artifacts.
Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
exact raw observation the producer fed to the original processor: per-camera uint8
frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
(``raw::state.<key>``, (B, T, dim)) and the language instruction
(``raw::language``, one string per batch element). ``raw_video_keys`` /
``raw_state_keys`` record the checkpoint modality-key order.
"""
data = np.load(path, allow_pickle=True)
markers = ("raw_video_keys", "raw_state_keys", "raw::language")
if any(marker not in data.files for marker in markers):
return None
video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
return {
"video": {k: data[f"raw::video.{k}"] for k in video_keys},
"state": {k: data[f"raw::state.{k}"] for k in state_keys},
"language": [str(t) for t in data["raw::language"].tolist()],
}
def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
"""Convert the producer's raw observation into a LeRobot policy batch."""
batch: dict[str, Any] = {}
for key, frames in raw["video"].items():
# (B, T, H, W, C) uint8 -> (B, T, C, H, W); the pack step converts back losslessly.
batch[f"{OBS_IMAGES}.{key}"] = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).contiguous()
# observation.state is the per-key state vectors (latest frame) concatenated in
# checkpoint modality-key order -- the layout the LeRobot pack step and the
# flattened checkpoint statistics expect.
state_parts = [torch.from_numpy(np.asarray(arr)[:, -1, :]).float() for arr in raw["state"].values()]
batch[OBS_STATE] = torch.cat(state_parts, dim=-1)
batch["task"] = list(raw["language"])
return batch
return original_action, inputs
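The seed handling in the `_load_artifact` variant that returns a seed can be isolated as a small helper: use the recorded seed when the artifact has one, otherwise warn and fall back to the module default. A sketch under stated assumptions; `resolve_seed` is a hypothetical name and a plain mapping stands in for the loaded `.npz` archive:

```python
import warnings
from typing import Any, Mapping

DEFAULT_SEED = 42  # mirrors the SEED fallback constant in the test module


def resolve_seed(artifact: Mapping[str, Any], artifact_name: str) -> int:
    """Return the seed recorded in the artifact, or DEFAULT_SEED with a warning."""
    if "seed" in artifact:
        return int(artifact["seed"])
    warnings.warn(
        f"Artifact '{artifact_name}' does not record the producer seed; "
        f"falling back to seed={DEFAULT_SEED}.",
        stacklevel=2,
    )
    return DEFAULT_SEED
```

Warning rather than failing keeps older artifacts usable, while still pointing at the likely cause if the parity comparison later fails.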
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
@@ -200,36 +139,6 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
return nested.get("inputs", nested)
def _assert_collated_parity(
embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
) -> None:
"""Compare one collated tensor produced by LeRobot against the original's."""
assert isinstance(lerobot_value, torch.Tensor), (
f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
f"{type(lerobot_value).__name__}, expected a tensor."
)
lerobot_t = lerobot_value.detach().cpu()
original_t = original_value.detach().cpu()
assert lerobot_t.shape == original_t.shape, (
f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs "
f"original={tuple(original_t.shape)}."
)
if exact:
mismatched = int((lerobot_t.long() != original_t.long()).sum())
assert mismatched == 0, (
f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
f"{mismatched}/{original_t.numel()} elements mismatch."
)
else:
lerobot_f, original_f = lerobot_t.float(), original_t.float()
max_diff = (lerobot_f - original_f).abs().max().item()
print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}")
assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), (
f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
)
@pytest.fixture(scope="module")
def lerobot_model():
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
@@ -256,7 +165,8 @@ def lerobot_model():
_ARTIFACTS = _discover_artifacts()
_requires_artifacts = pytest.mark.skipif(
@pytest.mark.skipif(
not _ARTIFACTS,
reason=(
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
@@ -264,30 +174,24 @@ _requires_artifacts = pytest.mark.skipif(
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
),
)
@_requires_artifacts
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
original_action, flat_inputs, seed = _load_artifact(artifact)
original_action, flat_inputs = _load_artifact(artifact)
model_inputs = _unflatten(flat_inputs)
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
torch.manual_seed(seed)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed_all(SEED)
with torch.inference_mode():
out = lerobot_model.get_action(model_inputs)
lerobot_action = out["action_pred"].float().cpu()
assert lerobot_action.shape == original_action.shape, (
f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
"The same checkpoint and inputs must produce identical shapes; this indicates an "
"action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
"utils/dump_original_n1_7.py)."
)
t = min(original_action.shape[1], lerobot_action.shape[1])
d = min(original_action.shape[2], lerobot_action.shape[2])
original_action = original_action[:, :t, :d]
lerobot_action = lerobot_action[:, :t, :d]
diff = torch.abs(lerobot_action - original_action)
max_diff = diff.max().item()
@@ -301,56 +205,3 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
        f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
        f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
    )
@_requires_artifacts
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_preprocessor_parity(embodiment_tag, artifact):
    """LeRobot's real preprocessor vs the original's collated tensors, from identical raw obs.
    Runs LeRobot's full preprocessor pipeline -- including the real Qwen3-VL chat
    template, tokenizer and image packing plus the checkpoint-driven state
    normalization (no mocks) -- on the raw observations recorded in the artifact, and
    compares every collated model input against the ones the original ``gr00t``
    processor produced from the same raw observations.
    """
    raw = _load_raw_observation(artifact)
    if raw is None:
        pytest.skip(
            f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
            "not record raw observations; regenerate it with the current dump script to run the "
            "preprocessor parity case."
        )
    _, flat_inputs, _ = _load_artifact(artifact)
    original_inputs = _unflatten(flat_inputs)
    ckpt = _resolve_checkpoint()
    from lerobot.policies.groot.configuration_groot import GrootConfig
    from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
    # CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
    config = GrootConfig(base_model_path=ckpt, embodiment_tag=embodiment_tag, device="cpu")
    preprocessor, _ = make_groot_pre_post_processors(config)
    processed = preprocessor(_raw_observation_to_lerobot_batch(raw))
    compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS)
    missing_original = [k for k in compared_keys if k not in original_inputs]
    missing_lerobot = [k for k in compared_keys if k not in processed]
    assert not missing_original, (
        f"[{embodiment_tag}] artifact collated inputs miss {missing_original} "
        f"(available: {sorted(original_inputs)}); regenerate the artifact with the current dump script."
    )
    assert not missing_lerobot, (
        f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys "
        f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})."
    )
    for name in compared_keys:
        _assert_collated_parity(
            embodiment_tag,
            name,
            processed[name],
            original_inputs[name],
            exact=name in _COLLATED_EXACT_KEYS,
        )
@@ -9,9 +9,6 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
imported in the same Python process. To keep the parity comparison FAIR, we run the
original model in its native env here and serialize, PER EMBODIMENT TAG:
* the RAW observation fed to the original processor (per-camera uint8 frames,
per-key state vectors, the language instruction), so the LeRobot side can also
run its OWN preprocessor on identical raw inputs and compare collated tensors,
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
byte-identical tensors -- same image preprocessing, tokenization, normalization),
* the random seed used right before the flow-matching sampler,
@@ -24,10 +21,8 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
``libero_sim``.
The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
(model parity), and the raw observation is replayed through LeRobot's own
preprocessor pipeline and compared against the collated inputs (preprocessor parity).
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
Usage:
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
@@ -67,7 +62,10 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
    # One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
    # Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
    # present-but-empty so the processor's state transform finds every expected key.
    state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
    state = {
        k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
        for k, dim in state_spec
    }
    language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
    return {"video": video, "state": state, "language": language}
@@ -79,25 +77,6 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
    lang_key = modality_cfg["language"].modality_keys[0]
    observation = make_observation(args.seed, video_keys, lang_key, state_spec)
    # Snapshot the RAW observation exactly as fed to the original processor below. The
    # consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
    # and compares the resulting collated tensors against the "in::" ones saved further
    # down. raw_state_keys records the checkpoint modality-key order, which is the
    # concatenation order of the flat LeRobot ``observation.state`` vector.
    spec_keys = [key for key, _ in state_spec]
    state_modality = modality_cfg.get("state")
    state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
    state_keys += [key for key in spec_keys if key not in state_keys]
    raw_language = [
        str(item[0]) if isinstance(item, (list, tuple)) else str(item)
        for item in observation["language"][lang_key]
    ]
    raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
    raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
    raw_flat["raw::language"] = np.array(raw_language, dtype=object)
    raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
    raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
    # Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
    policy.embodiment_tag = type(policy.embodiment_tag)(tag)
    policy.modality_configs = {
@@ -157,7 +136,6 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
        embodiment_tag=np.array(tag),
        meta_keys=np.array(list(meta.keys()), dtype=object),
        meta_dtypes=np.array(list(meta.values()), dtype=object),
        **raw_flat,
        **flat,
    )
    print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
@@ -203,12 +181,7 @@ def main():
        state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
        try:
            dump_one_tag(
                policy,
                fair_model,
                tag,
                all_modality[tag],
                state_spec,
                args,
                policy, fair_model, tag, all_modality[tag], state_spec, args,
                out_dir / f"original_n1_7_{tag}.npz",
            )
            done.append(tag)