Compare commits

..

1 Commits

Author SHA1 Message Date
Steven Palma ffdfb3d25f fix(groot): N1.7 config defaults, N1.5 rejection, and processor/model runtime fixes
Covers the GR00T N1.7 source trio (configuration, processor, model wrapper).
These three files are grouped together because processor_groot and
modeling_groot import GROOT_N1_5_REMOVAL_GUIDANCE defined in
configuration_groot.

Config:
- GrootConfig defaults are the N1.7 values; explicitly passed legacy N1.5-era
  values (chunk_size=50, max_state_dim=64, ...) are remapped with a warning
  instead of silently.
- action_decode_transform gains an 'auto' sentinel so an explicit 'none'
  opt-out wins over the libero_sim default and survives save/load round-trips.
- action_delta_indices is cached on the inputs that determine it.
- Legacy N1.5 checkpoints/configs (tokenizer_assets_repo, model_type/
  architectures/eagle backbone markers) are rejected with a single clear
  error pointing to lerobot==0.5.1.

Processor:
- GrootN17ActionDecodeStep handles the 2-D (B, D) actions delivered by sync
  select_action (relative eef/non-eef decode in eval/record flows).
- Postprocessor falls back to dataset stats when a raw checkpoint lacks the
  configured embodiment tag; raw-state cache is per-instance, not
  process-global; caller overrides (device, rename_map) are honored on the
  raw-checkpoint branch.
- Hub-hosted finetuned checkpoints resolve the processor config via
  hf_hub_download, with a tolerant retry when inspection fails.
- Camera/modality-key mismatches warn (including the zero-match fallback);
  deprecated Qwen2VLImageProcessorFast replaced with Qwen2VLImageProcessor;
  removed N1.5 processor steps are stubbed to raise the removal guidance and
  the action-unpack step is re-registered as _v2.

Model:
- use_bf16=False no longer crashes (compute_dtype only set when used).
- Flash-attention probe is diagnostic-only; forward raises on a missing loss;
  print() replaced with logging; N1.5 base-path mismatch includes the
  removal guidance.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-12 23:37:44 +02:00
7 changed files with 206 additions and 691 deletions
+1 -4
View File
@@ -4,9 +4,6 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
LeRobot integrates GR00T N1.7 through the `groot` policy type.
> [!WARNING]
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
## Model Overview
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
@@ -136,7 +133,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash
hf download nvidia/GR00T-N1.7-LIBERO \
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
+23 -52
View File
@@ -1,13 +1,6 @@
## Research Paper
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
> Current releases support GR00T N1.7 only.
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
## Repository
@@ -38,22 +31,12 @@ Hugging Face Models:
## Original-vs-LeRobot parity test
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
over every embodiment tag present in the checkpoint:
1. **Model parity** — given byte-identical pre-processed inputs and the same
flow-matching seed (recorded in each artifact), both implementations must produce
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
flow-matching prediction). Output shapes must match exactly; any action-horizon
or action-dim mismatch fails the test.
2. **Preprocessor parity** — given the identical raw observations (per-camera
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
state normalization, no mocks) must produce the **same collated model inputs**
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
`embodiment_id`) as the original package's processor.
produces the **same raw model output** (`get_action(...)["action_pred"]`, the
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
byte-identical pre-processed inputs and the same flow-matching seed. It is
parametrized over every embodiment tag present in the checkpoint.
### Why two environments
@@ -65,37 +48,25 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
So the test uses a **producer / consumer** split across two venvs:
1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
1. **Producer**`tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
gr00t venv. For each embodiment it builds dummy inputs generically from the
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
the processor modality configs), runs the original model, and saves to one `.npz`
per tag: the raw observations (`raw::` keys), the exact collated inputs
(`in::` keys), the seed, and the raw `action_pred`.
2. **Consumer** the pytest above, run in the _LeRobot_ venv. It discovers every
`.npz`; the model-parity case replays the byte-identical collated inputs through
the LeRobot model with the recorded seed and asserts the outputs match, and the
preprocessor-parity case replays the raw observations through LeRobot's full
preprocessor pipeline and asserts the collated tensors match.
> Artifacts generated by older versions of the dump script contain no `raw::`
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
> Re-run the producer to refresh them.
the processor modality configs), runs the original model, and saves the exact
collated inputs + raw `action_pred` to one `.npz` per tag.
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
`.npz`, replays the byte-identical inputs through the LeRobot model with the same
seed, and asserts the outputs match.
### Fairness controls
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
- **Same pre-processed inputs** — the original processor's `input_ids`,
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
model comparison isolates the model. LeRobot's own tokenization / image packing is
covered separately by the preprocessor-parity case, which compares its output
against those same collated tensors from identical raw observations.
fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
kernel/rounding noise, not an implementation difference.)
- **Same flow-matching seed** — fixed right before sampling on both sides; the
producer records it in each artifact (`--seed`, default 42) and the consumer
replays the recorded value.
- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
### How to run
@@ -119,15 +90,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
```
The `.npz` artifacts are local-only (gitignored, ~610 MB each) and are regenerated by
the producer; they are never committed. The tests **skip** (do not fail) on CI or
The `.npz` artifacts are local-only (gitignored, ~69 MB each) and are regenerated by
the producer; they are never committed. The test **skips** (does not fail) on CI or
when the checkpoint / artifacts are absent.
#### Env knobs (all optional)
| Var | Default | Purpose |
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
| Var | Default | Purpose |
|---|---|---|
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
@@ -14,7 +14,6 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
import torch
@@ -43,9 +42,6 @@ else:
Timesteps = None
logger = logging.getLogger(__name__)
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
require_package("diffusers", extra="groot")
@@ -269,8 +265,8 @@ class DiT(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
logger.debug(
"Total number of DiT parameters: %d",
print(
"Total number of DiT parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -430,8 +426,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
for _ in range(self.config.num_layers)
]
)
logger.debug(
"Total number of SelfAttentionTransformer parameters: %d",
print(
"Total number of SelfAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
@@ -42,14 +42,6 @@ GROOT_N1_5_REMOVAL_GUIDANCE = (
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
# falls back to these when a checkpoint ships no image sizing in its processor_config
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
# are resized to the expected resolution instead of being patchified at full camera
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
@@ -329,6 +321,9 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
# instead of being silently treated as N1.7 (see infer_groot_model_version).
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
@@ -370,7 +365,11 @@ class GrootConfig(PreTrainedConfig):
}
)
# Groot-specific model parameters
# Deprecated and unused: image sizing is handled by the backbone's image processor.
# Kept only so config.json files saved with earlier versions still parse.
image_size: tuple[int, int] = (256, 256)
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
@@ -386,43 +385,14 @@ class GrootConfig(PreTrainedConfig):
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
tokenizer_assets_repo: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
# Inference-only override for the number of flow-matching denoising steps used to decode an
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
# inference speed for action quality; applied at base-model load via _create_groot_model.
num_inference_timesteps: int | None = None
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
execution_horizon: int | None = None
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
# base-model build (so it applies during lerobot-train, which uses __init__ not
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
warm_start_embodiment_slot: int | str | None = None
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
# When True, GR00T converts absolute->relative inside its own pack step (training) and
# reconstructs absolute inside its own flat decode step (inference), using a cached
# reference state. The dataset stays absolute; compute relative ACTION stats with
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
# "['gripper']"` (this only rewrites stats, not actions).
use_relative_actions: bool = False
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
# Case-insensitive token match against action_feature_names.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
# identified and kept absolute. When None, the gripper cannot be identified.
action_feature_names: list[str] | None = None
# Fine-tuning control arguments
# Whether to fine-tune the llm backbone
@@ -458,13 +428,10 @@ class GrootConfig(PreTrainedConfig):
warmup_ratio: float = 0.05
use_bf16: bool = True
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
video_backend: str = "decord"
balance_dataset_weights: bool = True
balance_trajectory_weights: bool = True
@@ -478,6 +445,9 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
# serialized into every groot checkpoint config.json, so a value here means a legacy
# N1.5 checkpoint or config is being loaded.
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
@@ -612,11 +582,22 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
"""Return indices for delta actions.
The model action horizon is read from the checkpoint's processor_config.json
when available; the result is cached (keyed on the inputs that determine it) so
repeated access during dataset/training setup does not re-read from disk.
"""
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
cached = getattr(self, "_action_delta_indices_cache", None)
if cached is not None and cached[0] == cache_key:
return cached[1]
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
return list(range(min(self.chunk_size, model_action_horizon)))
indices = list(range(min(self.chunk_size, model_action_horizon)))
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
return indices
@property
def reward_delta_indices(self) -> None:
+9 -13
View File
@@ -32,7 +32,6 @@ from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
@@ -72,13 +71,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 16,
"select_layer": 12,
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
"image_crop_size": (230, 230),
"image_target_size": (256, 256),
"shortest_image_edge": None,
"crop_fraction": None,
"random_rotation_angle": None,
@@ -820,14 +819,11 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
def get_backbone_cls(config: GR00TN17Config):
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
return Qwen3Backbone
if config.backbone_model_type == "qwen":
logger.warning(
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
"backbone because backbone_model_type='qwen'.",
config.model_name,
)
if (
config.backbone_model_type == "qwen"
or "nvidia/Cosmos-Reason2" in config.model_name
or "Qwen/Qwen3-VL" in config.model_name
):
return Qwen3Backbone
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
@@ -913,7 +909,7 @@ class GR00TN17(PreTrainedModel):
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
for key in ("cache_dir", "local_files_only", "token"):
for key in ("revision", "cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
+6 -125
View File
@@ -54,98 +54,6 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy")
def _resolve_embodiment_id(value: int | str) -> int:
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
Raises ValueError listing the known keys if the name is unknown.
"""
from .processor_groot import N1_7_EMBODIMENT_MAPPING
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
if isinstance(value, int):
return value
if value in N1_7_EMBODIMENT_MAPPING:
return N1_7_EMBODIMENT_MAPPING[value]
raise ValueError(
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
)
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
"""Copy category-specific action-head weights from one embodiment slot to another.
Used at base-model load (training only) to warm-start a cold target embodiment slot
(e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
parameters across every CategorySpecificLinear in the action head's state encoder,
action encoder, and action decoder. No-ops (with a logged warning) if the ids are out
of range or identical.
"""
if source_id == target_id:
logger.warning(
"GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
"skipping (nothing to copy).",
source_id,
)
return
action_head = getattr(model, "action_head", None)
if action_head is None:
logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
return
# Each entry is (submodule, [CategorySpecificLinear attribute names]).
linear_groups = [
(getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
(getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
(getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
]
copied: list[str] = []
with torch.no_grad():
for submodule, attr_names in linear_groups:
if submodule is None:
continue
submodule_name = type(submodule).__name__
for attr_name in attr_names:
lin = getattr(submodule, attr_name, None)
if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
continue
num_categories = lin.W.shape[0]
if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
logger.warning(
"GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
"for %s.%s (num_categories=%d); skipping this layer.",
source_id,
target_id,
submodule_name,
attr_name,
num_categories,
)
continue
lin.W.data[target_id] = lin.W.data[source_id].clone()
lin.b.data[target_id] = lin.b.data[source_id].clone()
copied.append(f"{submodule_name}.{attr_name}")
if copied:
logger.info(
"GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
"to slot %d for: %s.",
source_id,
target_id,
", ".join(copied),
)
else:
logger.warning(
"GR00T warm_start_embodiment_slot: no action-head weights were copied "
"(source_id=%d, target_id=%d).",
source_id,
target_id,
)
class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration."""
@@ -185,24 +93,11 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True},
)
# Inference-only override for the number of flow-matching denoising steps. The action
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
# t schedule adapt automatically.
if self.config.num_inference_timesteps is not None:
n = int(self.config.num_inference_timesteps)
model.config.num_inference_timesteps = n
model.action_head.num_inference_timesteps = n
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
if self.config.warm_start_embodiment_slot is not None:
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
_warm_start_embodiment_slot(model, source_id, target_id)
# GR00TN17 defines no compute_dtype attribute, so only record the
# bf16 preference when it is enabled instead of reading a default back.
if self.config.use_bf16:
model.compute_dtype = "bfloat16"
model.config.compute_dtype = "bfloat16"
return model
@@ -371,11 +266,7 @@ class GrootPolicy(PreTrainedPolicy):
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
# An explicit config override caps the open-loop horizon (inference cadence), overriding
# the value inferred from the checkpoint/embodiment.
if self.config.execution_horizon is not None:
horizons.append(max(1, int(self.config.execution_horizon)))
return max(1, min(horizons))
return min(horizons)
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
@@ -543,16 +434,6 @@ class GrootPolicy(PreTrainedPolicy):
"""
self.eval()
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
from .processor_groot import _GROOT_REF_HOLDER_KEY
holder = batch.get(_GROOT_REF_HOLDER_KEY)
if holder is not None:
holder.freeze()
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
+133 -440
View File
@@ -23,10 +23,9 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torchvision.transforms.v2.functional as tv_functional
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms import InterpolationMode
from lerobot.utils.import_utils import _transformers_available
@@ -47,8 +46,6 @@ from lerobot.processor import (
RenameObservationsProcessorStep,
batch_to_transition,
policy_action_to_transition,
to_absolute_actions,
to_relative_actions,
transition_to_batch,
transition_to_policy_action,
)
@@ -61,14 +58,11 @@ from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
from lerobot.utils.device_utils import get_safe_torch_device
from .configuration_groot import (
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
GROOT_N1_5_REMOVAL_GUIDANCE,
GROOT_N1_7_BACKBONE_MODEL,
N1_7_DEFAULT_IMAGE_CROP_SIZE,
N1_7_DEFAULT_IMAGE_TARGET_SIZE,
GrootConfig,
is_raw_groot_n1_7_checkpoint,
)
@@ -90,30 +84,6 @@ N1_7_EMBODIMENT_MAPPING = {
}
_GROOT_REF_HOLDER_KEY = "_groot_relative_ref_holder" # private; dropped by _filter_groot_inputs, never reaches the model
class _GrootRelativeRefHolder:
"""Runtime-only carrier shared (by object identity) between the pack step (owner/writer of the
live reference), GrootPolicy.predict_action_chunk (freezes it at a real predict event), and the
decode step (reads the frozen reference). Not serialized. One instance per pack step."""
__slots__ = ("reference_state", "raw_state", "frozen_reference", "frozen_raw")
def __init__(self):
self.reference_state = None
self.raw_state = None
self.frozen_reference = None
self.frozen_raw = None
def freeze(self) -> None:
self.frozen_reference = self.reference_state
self.frozen_raw = self.raw_state
def clear(self) -> None:
self.reference_state = self.raw_state = self.frozen_reference = self.frozen_raw = None
@dataclass
class _GrootN17CheckpointProcessorAssets:
"""Processor metadata loaded from a raw Isaac-GR00T N1.7 checkpoint.
@@ -143,39 +113,6 @@ class _GrootN17CheckpointProcessorAssets:
use_albumentations: bool
def _resolve_base_model_local_dir(base_model_path: str | None) -> str | None:
"""Resolve a base model path to a local snapshot dir holding its sidecar JSONs.
``is_raw_groot_n1_7_checkpoint`` needs a local directory (or config.json) to inspect, so a
bare HF repo-id (e.g. ``nvidia/GR00T-N1.7-3B``) would never be recognised as a raw N1.7
checkpoint and the processor would fall back to LeRobot default image geometry instead of the
checkpoint's processor_config.json geometry. When the path is not already a local dir, this
downloads just the JSON sidecars and returns the local snapshot dir. Offline-safe: any failure
returns the original string unchanged. Only used on the fresh-build (training) path; inference
loads the serialized processor, so no per-inference network call is added.
"""
if base_model_path is None:
return None
if Path(base_model_path).expanduser().is_dir():
return base_model_path
try:
from huggingface_hub import snapshot_download
local_dir = snapshot_download(
base_model_path,
repo_type="model",
allow_patterns=["*.json"],
)
logging.debug(
"Resolved GR00T base model '%s' to local snapshot '%s' for processor asset loading.",
base_model_path,
local_dir,
)
return local_dir
except Exception: # noqa: BLE001 (offline-safe: fall back to the original path on any failure)
return base_model_path
def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None:
"""Load N1.7 processor settings from checkpoint sidecar JSON files.
@@ -183,11 +120,10 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
can keep using caller-provided dataset stats and config values.
"""
resolved_base_model_path = _resolve_base_model_local_dir(config.base_model_path)
if not is_raw_groot_n1_7_checkpoint(resolved_base_model_path):
if not is_raw_groot_n1_7_checkpoint(config.base_model_path):
return None
checkpoint_path = Path(resolved_base_model_path).expanduser()
checkpoint_path = Path(config.base_model_path).expanduser()
processor_config = _read_json(checkpoint_path / "processor_config.json")
processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict):
@@ -512,74 +448,60 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
return any(bool(modality_stats) for modality_stats in stats.values())
def _build_relative_action_mask(
action_dim: int,
exclude_joints: list[str] | None,
action_names: list[str] | None,
) -> list[bool]:
"""Build the per-dim relative-action mask (True = convert to relative, False = keep absolute).
def _legacy_groot_processor_overrides(
config: GrootConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
preprocessor_overrides: dict[str, Any] | None = None,
postprocessor_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Patch older serialized Groot processors with fields current processors expect."""
Replicates ``RelativeActionsProcessorStep._build_mask`` semantics: dims are excluded
(kept absolute) by case-insensitive token match against ``action_names``.
preprocessor_overrides = dict(preprocessor_overrides or {})
postprocessor_overrides = dict(postprocessor_overrides or {})
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
When ``action_names`` is None we cannot identify the gripper, so this returns all-True
(every dim treated as relative). The user should ensure ``config.action_feature_names`` is
populated (the factory does this from dataset meta) so the gripper can be kept absolute;
arm-relative still works either way, but a missing-name gripper would be treated as relative.
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
pack_input_overrides["normalize_min_max"] = True
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
try:
env_action_dim = int(config.output_features[ACTION].shape[0])
except Exception:
env_action_dim = 0
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
return preprocessor_overrides, postprocessor_overrides
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
"""Check whether a serialized processor pipeline contains a registry step.
Resolves the processor config from a local directory or, for Hub repo ids,
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
False when the config cannot be resolved; loading then proceeds with the
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
without them if they do not match the serialized pipeline.
"""
if not exclude_joints or action_names is None:
return [True] * action_dim
exclude_tokens = [str(name).lower() for name in exclude_joints if name]
if not exclude_tokens:
return [True] * action_dim
mask: list[bool] = []
for name in action_names[:action_dim]:
action_name = str(name).lower()
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < action_dim:
mask.extend([True] * (action_dim - len(mask)))
return mask
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
(e.g. a typo) is left in place and still raises.
"""
if not overrides:
return overrides
filtered: dict[str, Any] = {}
for key, value in overrides.items():
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
logging.debug(
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
"no matching step (see GrootConfig.normalization_mapping).",
key,
path = Path(pretrained_path).expanduser()
if path.is_dir():
config = _read_json(path / config_filename)
elif path.exists():
return False
else:
try:
config_path = hf_hub_download(
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
)
continue
filtered[key] = value
return filtered
except Exception:
return False
config = _read_json(Path(config_path))
steps = config.get("steps", [])
if not isinstance(steps, list):
return False
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
def _apply_groot_step_overrides(
@@ -595,8 +517,7 @@ def _apply_groot_step_overrides(
steps by registry name only — prefer registry names so overrides keep
working after the checkpoint is converted and reloaded from a serialized
pipeline). Keys or fields that match nothing raise instead of being dropped
silently (standard normalization keys GR00T has no step for are removed
beforehand by ``_drop_groot_absent_standard_overrides``).
silently.
"""
if not overrides:
@@ -652,13 +573,7 @@ def make_groot_pre_post_processors_from_pretrained(
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
# paths raise. This must happen before either branch below.
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
"""Load Groot processors while preserving compatibility with older serialized configs."""
if is_raw_groot_n1_7_checkpoint(pretrained_path):
processor_cfg = copy(config)
@@ -674,13 +589,49 @@ def make_groot_pre_post_processors_from_pretrained(
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
return preprocessor, postprocessor
preprocessor, postprocessor = _load_groot_processor_pipelines(
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
if _pretrained_processor_config_has_step(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
postprocessor_config_filename,
"groot_n1_7_action_decode_v1",
):
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
# action decoder. Adding the legacy action-unpack override would target
# a step that is not present and break loading.
applied_legacy_overrides = False
preprocessor_overrides = caller_preprocessor_overrides
postprocessor_overrides = caller_postprocessor_overrides
else:
applied_legacy_overrides = True
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
config=config,
dataset_stats=dataset_stats,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
)
try:
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=preprocessor_overrides,
postprocessor_overrides=postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
except KeyError:
if not applied_legacy_overrides:
raise
# The legacy overrides target steps that are absent from the serialized
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
# config could not be inspected before loading); retry with the caller
# overrides only.
preprocessor, postprocessor = _load_groot_processor_pipelines(
pretrained_path,
preprocessor_overrides=caller_preprocessor_overrides,
postprocessor_overrides=caller_postprocessor_overrides,
preprocessor_config_filename=preprocessor_config_filename,
postprocessor_config_filename=postprocessor_config_filename,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
@@ -747,15 +698,8 @@ def _reconnect_groot_n1_7_pack_decode_steps(
if pack_step is None:
return
# Both decode steps read the pack step's cached state via a non-serialized ``pack_step`` link:
# GrootN17ActionDecodeStep reads the per-modality raw state; the relative-action path
# (GrootActionUnpackUnnormalizeStep) reads the cached reference state. Restore both links after
# deserialization.
for step in postprocessor.steps:
if (
isinstance(step, (GrootN17ActionDecodeStep, GrootActionUnpackUnnormalizeStep))
and step.pack_step is None
):
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
step.pack_step = pack_step
@@ -833,45 +777,23 @@ def make_groot_pre_post_processors(
video_modality_keys=video_modality_keys,
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
use_relative_actions=config.use_relative_actions,
relative_exclude_joints=config.relative_exclude_joints,
action_feature_names=config.action_feature_names,
)
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
# when it provides an image_target_size; otherwise fall back to the geometry the
# N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no
# processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new
# embodiment, where checkpoint_assets is None) would patchify full-resolution camera
# frames, inflating the VLM token count -- slowing both dataloading_s and update_s --
# and feeding the model a resolution it was not trained on.
if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None:
image_target_size = checkpoint_assets.image_target_size
image_crop_size = checkpoint_assets.image_crop_size
shortest_image_edge = checkpoint_assets.shortest_image_edge
crop_fraction = checkpoint_assets.crop_fraction
else:
image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE)
image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE)
shortest_image_edge = None
crop_fraction = None
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
pack_step,
GrootN17VLMEncodeStep(
model_name=config.n1_7_backbone_model,
image_crop_size=image_crop_size,
image_target_size=image_target_size,
shortest_image_edge=shortest_image_edge,
crop_fraction=crop_fraction,
use_albumentations=use_albumentations,
# Run the image resize/normalize/patchify on the training device when
# possible instead of the single CPU main-loop thread (the dominant
# cost folded into dataloading_s).
device=config.device,
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
shortest_image_edge=checkpoint_assets.shortest_image_edge
if checkpoint_assets is not None
else None,
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
use_albumentations=checkpoint_assets.use_albumentations
if checkpoint_assets is not None
else False,
),
DeviceProcessorStep(device=config.device),
]
@@ -895,10 +817,6 @@ def make_groot_pre_post_processors(
stats=padded_stats,
normalize_min_max=True,
clip_normalized_action=True,
use_relative_actions=config.use_relative_actions,
relative_exclude_joints=config.relative_exclude_joints,
action_feature_names=config.action_feature_names,
pack_step=pack_step,
)
else:
action_decode_step = GrootN17ActionDecodeStep(
@@ -1114,61 +1032,6 @@ def _transform_n1_7_image_for_vlm(
return image
def _transform_n1_7_image_for_vlm_torch(
image: torch.Tensor,
*,
image_crop_size: list[int] | None,
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
) -> torch.Tensor:
"""Torch/torchvision port of the non-albumentations branch of
:func:`_transform_n1_7_image_for_vlm`.
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper.
"""
if image_target_size is None:
return image
target_h, target_w = image_target_size
_, height, width = image.shape
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
resize_edge = shortest_image_edge or target_h
image = tv_functional.resize(
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
)
if crop_fraction is None and image_crop_size is not None:
crop_fraction = image_crop_size[0] / float(target_h)
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
# Match the PIL helper's center crop exactly: round() the crop size but
# floor() the offset (torchvision.center_crop rounds the offset, which
# shifts the region by 1px when (edge - crop) is odd).
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
top = max(0, (image.shape[-2] - crop_h) // 2)
left = max(0, (image.shape[-1] - crop_w) // 2)
image = image[..., top : top + crop_h, left : left + crop_w]
if tuple(image.shape[-2:]) != (target_h, target_w):
image = tv_functional.resize(
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
)
return image
@dataclass
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
class GrootN17PackInputsStep(ProcessorStep):
@@ -1195,18 +1058,11 @@ class GrootN17PackInputsStep(ProcessorStep):
video_modality_keys: list[str] | None = None
raw_stats: dict[str, Any] | None = None
modality_config: dict[str, Any] | None = None
# Opt-in relative-action support: convert absolute->relative actions inside this pack step
# (training) using the cached raw reference state, keeping excluded joints (e.g. gripper)
# absolute. The paired GrootActionUnpackUnnormalizeStep reconstructs absolute on decode.
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
# Unused: kept so serialized configs that include it still load. The raw
# state cache is per instance (_last_raw_state), never process-global.
state_cache_key: str = ""
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
_last_reference_state: torch.Tensor | None = field(default=None, init=False, repr=False)
_warned_image_keys: bool = field(default=False, init=False, repr=False)
_ref_holder: "_GrootRelativeRefHolder" = field(
default_factory=_GrootRelativeRefHolder, init=False, repr=False
)
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
available = {key for key in obs if key.startswith(OBS_IMAGES)}
@@ -1328,7 +1184,6 @@ class GrootN17PackInputsStep(ProcessorStep):
start_idx += dim
if grouped:
self._last_raw_state = grouped
self._ref_holder.raw_state = grouped
img_keys = self._ordered_image_keys(obs)
if img_keys:
@@ -1348,9 +1203,6 @@ class GrootN17PackInputsStep(ProcessorStep):
formalize_language=self.formalize_language,
)
# Reference state for relative-action conversion (RAW, pre-normalization, (B, D)). Cached
# regardless of whether an action is present so inference caches it too for decode.
relative_reference_state: torch.Tensor | None = None
if OBS_STATE in obs:
state = obs[OBS_STATE]
if state.dim() != 2:
@@ -1359,10 +1211,6 @@ class GrootN17PackInputsStep(ProcessorStep):
if dim > self.max_state_dim:
raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.")
_cache_raw_state(state)
if self.use_relative_actions:
relative_reference_state = state.detach().clone()
self._last_reference_state = relative_reference_state
self._ref_holder.reference_state = relative_reference_state
if self.normalize_min_max:
state = _min_max_norm(state, OBS_STATE)
state = state.unsqueeze(1)
@@ -1385,19 +1233,6 @@ class GrootN17PackInputsStep(ProcessorStep):
raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.")
if dim > self.max_action_dim:
raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.")
# Convert absolute->relative BEFORE normalization. The mask keeps excluded joints (e.g.
# gripper) absolute; to_relative_actions broadcasts the (B, D) reference state over T.
if self.use_relative_actions:
if relative_reference_state is None:
raise RuntimeError(
"GrootN17PackInputsStep.use_relative_actions requires observation.state "
"(OBS_STATE) to be present alongside the action to build the relative "
"reference, but no state was found in this transition."
)
mask = _build_relative_action_mask(
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
)
action = to_relative_actions(action, relative_reference_state, mask)
if self.normalize_min_max:
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
action = flat.view(bsz, horizon, dim)
@@ -1437,12 +1272,6 @@ class GrootN17PackInputsStep(ProcessorStep):
comp["action_mask"] = action_mask
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.int32, device=device)
# Publish the runtime-only reference holder so the policy can freeze it at the predict
# event and the decode step can read the frozen reference. It rides in COMPLEMENTARY_DATA,
# survives the VLM-encode step and DeviceProcessorStep as a non-tensor, and reaches the
# policy via the batch (by object identity) through the pipeline's shallow copies.
comp[_GROOT_REF_HOLDER_KEY] = self._ref_holder
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
@@ -1467,9 +1296,6 @@ class GrootN17PackInputsStep(ProcessorStep):
"video_modality_keys": self.video_modality_keys,
"raw_stats": self.raw_stats,
"modality_config": self.modality_config,
"use_relative_actions": self.use_relative_actions,
"relative_exclude_joints": self.relative_exclude_joints,
"action_feature_names": self.action_feature_names,
}
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
@@ -1477,23 +1303,6 @@ class GrootN17PackInputsStep(ProcessorStep):
return self._last_raw_state
def get_cached_reference_state(self) -> torch.Tensor | None:
"""Return the latest RAW (pre-normalization) (B, D) state used for relative-action conversion."""
return self._last_reference_state
def get_reference_holder(self) -> "_GrootRelativeRefHolder":
"""Return the runtime-only holder shared with the policy (writer) and decode step (reader)."""
return self._ref_holder
def reset(self) -> None:
"""Clear cached per-episode relative-action references (sync engine resets on episode boundaries)."""
self._last_reference_state = None
self._last_raw_state = None
self._ref_holder.clear()
def state_dict(self) -> dict[str, torch.Tensor]:
if not self.stats:
return {}
@@ -1524,12 +1333,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
an image item in the same chat message so the resulting image tokens match
the temporal VLM packing used by Isaac-GR00T.
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
CUDA device, the resize/rescale/normalize/patchify run there instead of on the
single CPU main-loop thread. This keeps the output bit-identical on CPU and
moves the dominant preprocessing cost off the critical path on GPU.
"""
model_name: str = GROOT_N1_7_BACKBONE_MODEL
@@ -1538,7 +1341,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
shortest_image_edge: int | None = None
crop_fraction: float | None = None
use_albumentations: bool = False
device: str | None = None
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
@@ -1547,70 +1349,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
self._proc = _build_n1_7_processor(self.model_name)
return self._proc
def _target_device(self) -> torch.device | None:
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
if self.device is None or self.use_albumentations:
return None
try:
return get_safe_torch_device(self.device)
except (AssertionError, RuntimeError):
# A device serialized at train time (e.g. "cuda") may be unavailable
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
# this step is not in the standard device-override set. Fall back to
# the CPU path, which is bit-identical, instead of crashing.
return None
def _build_sample_images(
self, video: Any, batch_size: int, target_device: torch.device | None
) -> list[list[Any]]:
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
``target_device`` when set) for the torchvision-backed Qwen processor.
"""
if self.use_albumentations:
video_np = np.asarray(video)
return [
[
_transform_n1_7_image_for_vlm(
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
use_albumentations=True,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
for batch_idx in range(batch_size)
]
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
if target_device is not None and video_t.device != target_device:
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
frames_per_sample: list[list[Any]] = []
for batch_idx in range(batch_size):
sample = video_t[batch_idx] # (T, V, C, H, W)
frames_per_sample.append(
[
_transform_n1_7_image_for_vlm_torch(
sample[timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
)
return frames_per_sample
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
@@ -1618,25 +1356,33 @@ class GrootN17VLMEncodeStep(ProcessorStep):
if video is None:
return transition
batch_size = int(video.shape[0])
languages = _prepare_n1_7_language_batch(
comp.get("language"),
batch_size,
video.shape[0],
formalize_language=False,
)
target_device = self._target_device()
sample_images = self._build_sample_images(video, batch_size, target_device)
texts: list[str] = []
images: list[Any] = []
for batch_idx in range(batch_size):
frames = sample_images[batch_idx]
images: list[Image.Image] = []
for batch_idx in range(video.shape[0]):
sample = video[batch_idx] # (T, V, H, W, C)
sample_images = [
_transform_n1_7_image_for_vlm(
Image.fromarray(sample[timestep, view_idx]),
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
use_albumentations=self.use_albumentations,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
]
conversation = [
{
"role": "user",
"content": [
*[{"type": "image", "image": image} for image in frames],
*[{"type": "image", "image": image} for image in sample_images],
{"type": "text", "text": languages[batch_idx]},
],
}
@@ -1648,17 +1394,9 @@ class GrootN17VLMEncodeStep(ProcessorStep):
add_generation_prompt=False,
)
)
images.extend(frames)
images.extend(sample_images)
proc_kwargs: dict[str, Any] = {
"text": texts,
"images": images,
"return_tensors": "pt",
"padding": True,
}
if target_device is not None:
proc_kwargs["device"] = str(target_device)
encoded = self.proc(**proc_kwargs)
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
for key, value in encoded.items():
comp[key] = value
obs.pop("video", None)
@@ -1677,7 +1415,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"shortest_image_edge": self.shortest_image_edge,
"crop_fraction": self.crop_fraction,
"use_albumentations": self.use_albumentations,
"device": self.device,
}
@@ -1828,6 +1565,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
modality_config: dict[str, Any] | None = None
use_percentiles: bool = False
use_relative_action: bool = False
# Unused: kept so serialized configs that include it still load.
state_cache_key: str = ""
action_decode_transform: str | None = None
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
@@ -1890,14 +1629,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
start_idx += dim
if self.use_relative_action:
# Prefer the raw state frozen at the chunk-prediction event (see the relative-action
# branch of GrootActionUnpackUnnormalizeStep). Falls back to the live cached raw state.
holder = self.pack_step.get_reference_holder() if self.pack_step is not None else None
raw_state = None
if holder is not None:
raw_state = holder.frozen_raw if holder.frozen_raw is not None else holder.raw_state
if raw_state is None and self.pack_step is not None:
raw_state = self.pack_step.get_cached_raw_state()
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
if raw_state is None:
raise RuntimeError(
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
@@ -1962,10 +1694,10 @@ class GrootN17ActionDecodeStep(ProcessorStep):
}
@dataclass
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
# action chunks to the last timestep, so old serialized v1 pipelines must not
# silently load into it (v1 is stubbed below with the removal guidance).
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0
@@ -1975,13 +1707,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
clip_normalized_action: bool = False
libero_gripper_action: bool = False
libero_gripper_binarize: bool = True
# Opt-in relative-action reconstruction (paired with GrootN17PackInputsStep). After the
# min-max inverse, relative deltas (arm) + absolute gripper are converted back to absolute
# using the reference state cached by the linked pack_step (re-linked on reload).
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
pack_step: "GrootN17PackInputsStep | None" = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
@@ -2021,35 +1746,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
inv = (action + 1.0) * 0.5 * safe_denom + min_v
action = torch.where(mask, inv, min_v)
# Reconstruct absolute actions from relative deltas (arm) + absolute gripper, using the
# reference state cached by the linked pack step. The link is restored on reload by
# _reconnect_groot_n1_7_pack_decode_steps.
if self.use_relative_actions:
if self.pack_step is None:
raise RuntimeError(
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires a linked "
"GrootN17PackInputsStep to read the cached reference state, but pack_step is None. "
"Build both pipelines through make_groot_pre_post_processors (or load them together "
"via make_groot_pre_post_processors_from_pretrained)."
)
# Prefer the reference frozen at the chunk-prediction event (set by
# GrootPolicy.predict_action_chunk via the shared holder) so every popped delta of a
# chunk reconstructs against that chunk's start state S_T, not the per-tick latest
# state. Falls back to the live reference when nothing was frozen (e.g. decode without
# a preceding predict event, or RTC/async where frozen == live).
holder = self.pack_step.get_reference_holder()
ref = holder.frozen_reference if holder.frozen_reference is not None else holder.reference_state
if ref is None:
raise RuntimeError(
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires the reference state "
"cached by its connected GrootN17PackInputsStep to convert relative actions back to "
"absolute. Run the preprocessor on an observation before decoding actions."
)
relative_mask = _build_relative_action_mask(
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
)
action = to_absolute_actions(action, ref, relative_mask)
if self.libero_gripper_action and action.shape[-1] >= 7:
gripper = action[..., -1]
if self.libero_gripper_binarize:
@@ -2077,9 +1773,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
"clip_normalized_action": self.clip_normalized_action,
"libero_gripper_action": self.libero_gripper_action,
"libero_gripper_binarize": self.libero_gripper_binarize,
"use_relative_actions": self.use_relative_actions,
"relative_exclude_joints": self.relative_exclude_joints,
"action_feature_names": self.action_feature_names,
}
def state_dict(self) -> dict[str, torch.Tensor]: