mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-14 23:09:54 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 404751ba8b | |||
| 559cba212d | |||
| 378897800a | |||
| fcb371eddd | |||
| 895eaf0d7c | |||
| edda8552ec |
@@ -4,6 +4,9 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
|
||||
|
||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||
|
||||
> [!WARNING]
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
|
||||
## Model Overview
|
||||
|
||||
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
|
||||
@@ -133,7 +136,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
|
||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||
|
||||
```bash
|
||||
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
|
||||
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
--include "libero_spatial/*" \
|
||||
--local-dir ./GR00T-N1.7-LIBERO
|
||||
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
|
||||
|
||||
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
|
||||
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
|
||||
> Current releases support GR00T N1.7 only.
|
||||
|
||||
## Repository
|
||||
|
||||
@@ -31,12 +38,22 @@ Hugging Face Models:
|
||||
|
||||
## Original-vs-LeRobot parity test
|
||||
|
||||
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
|
||||
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
|
||||
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
||||
produces the **same raw model output** (`get_action(...)["action_pred"]`, the
|
||||
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
|
||||
byte-identical pre-processed inputs and the same flow-matching seed. It is
|
||||
parametrized over every embodiment tag present in the checkpoint.
|
||||
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
|
||||
over every embodiment tag present in the checkpoint:
|
||||
|
||||
1. **Model parity** — given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed (recorded in each artifact), both implementations must produce
|
||||
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
|
||||
flow-matching prediction). Output shapes must match exactly; any action-horizon
|
||||
or action-dim mismatch fails the test.
|
||||
2. **Preprocessor parity** — given the identical raw observations (per-camera
|
||||
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
|
||||
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
|
||||
state normalization, no mocks) must produce the **same collated model inputs**
|
||||
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
|
||||
`embodiment_id`) as the original package's processor.
|
||||
|
||||
### Why two environments
|
||||
|
||||
@@ -48,25 +65,37 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
|
||||
|
||||
So the test uses a **producer / consumer** split across two venvs:
|
||||
|
||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
|
||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
|
||||
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
||||
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
||||
the processor modality configs), runs the original model, and saves the exact
|
||||
collated inputs + raw `action_pred` to one `.npz` per tag.
|
||||
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
|
||||
`.npz`, replays the byte-identical inputs through the LeRobot model with the same
|
||||
seed, and asserts the outputs match.
|
||||
the processor modality configs), runs the original model, and saves to one `.npz`
|
||||
per tag: the raw observations (`raw::` keys), the exact collated inputs
|
||||
(`in::` keys), the seed, and the raw `action_pred`.
|
||||
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
|
||||
`.npz`; the model-parity case replays the byte-identical collated inputs through
|
||||
the LeRobot model with the recorded seed and asserts the outputs match, and the
|
||||
preprocessor-parity case replays the raw observations through LeRobot's full
|
||||
preprocessor pipeline and asserts the collated tensors match.
|
||||
|
||||
> Artifacts generated by older versions of the dump script contain no `raw::`
|
||||
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
|
||||
> Re-run the producer to refresh them.
|
||||
|
||||
### Fairness controls
|
||||
|
||||
- **Same pre-processed inputs** — the original processor's `input_ids`,
|
||||
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
|
||||
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
|
||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
|
||||
model comparison isolates the model. LeRobot's own tokenization / image packing is
|
||||
covered separately by the preprocessor-parity case, which compares its output
|
||||
against those same collated tensors from identical raw observations.
|
||||
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
||||
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
||||
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
||||
kernel/rounding noise, not an implementation difference.)
|
||||
- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
|
||||
- **Same flow-matching seed** — fixed right before sampling on both sides; the
|
||||
producer records it in each artifact (`--seed`, default 42) and the consumer
|
||||
replays the recorded value.
|
||||
|
||||
### How to run
|
||||
|
||||
@@ -90,15 +119,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
|
||||
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
||||
```
|
||||
|
||||
The `.npz` artifacts are local-only (gitignored, ~6–9 MB each) and are regenerated by
|
||||
the producer; they are never committed. The test **skips** (does not fail) on CI or
|
||||
The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
|
||||
the producer; they are never committed. The tests **skip** (do not fail) on CI or
|
||||
when the checkpoint / artifacts are absent.
|
||||
|
||||
#### Env knobs (all optional)
|
||||
|
||||
| Var | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
||||
| Var | Default | Purpose |
|
||||
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
|
||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -42,6 +43,9 @@ else:
|
||||
Timesteps = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
require_package("diffusers", extra="groot")
|
||||
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
logger.debug(
|
||||
"Total number of DiT parameters: %d",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
logger.debug(
|
||||
"Total number of SelfAttentionTransformer parameters: %d",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
|
||||
@@ -321,9 +321,6 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
|
||||
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
|
||||
# instead of being silently treated as N1.7 (see infer_groot_model_version).
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
@@ -365,11 +362,7 @@ class GrootConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# Deprecated and unused: image sizing is handled by the backbone's image processor.
|
||||
# Kept only so config.json files saved with earlier versions still parse.
|
||||
image_size: tuple[int, int] = (256, 256)
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
# Groot-specific model parameters
|
||||
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
@@ -385,11 +378,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
|
||||
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
|
||||
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
|
||||
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
|
||||
@@ -428,10 +416,13 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
video_backend: str = "decord"
|
||||
balance_dataset_weights: bool = True
|
||||
balance_trajectory_weights: bool = True
|
||||
@@ -445,9 +436,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
|
||||
# serialized into every groot checkpoint config.json, so a value here means a legacy
|
||||
# N1.5 checkpoint or config is being loaded.
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
@@ -582,22 +570,11 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions.
|
||||
|
||||
The model action horizon is read from the checkpoint's processor_config.json
|
||||
when available; the result is cached (keyed on the inputs that determine it) so
|
||||
repeated access during dataset/training setup does not re-read from disk.
|
||||
"""
|
||||
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
|
||||
cached = getattr(self, "_action_delta_indices_cache", None)
|
||||
if cached is not None and cached[0] == cache_key:
|
||||
return cached[1]
|
||||
"""Return indices for delta actions."""
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
indices = list(range(min(self.chunk_size, model_action_horizon)))
|
||||
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
|
||||
return indices
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 12,
|
||||
"select_layer": 16,
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"load_bf16": False,
|
||||
@@ -819,11 +819,14 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
||||
|
||||
|
||||
def get_backbone_cls(config: GR00TN17Config):
|
||||
if (
|
||||
config.backbone_model_type == "qwen"
|
||||
or "nvidia/Cosmos-Reason2" in config.model_name
|
||||
or "Qwen/Qwen3-VL" in config.model_name
|
||||
):
|
||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||
return Qwen3Backbone
|
||||
if config.backbone_model_type == "qwen":
|
||||
logger.warning(
|
||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||
"backbone because backbone_model_type='qwen'.",
|
||||
config.model_name,
|
||||
)
|
||||
return Qwen3Backbone
|
||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||
|
||||
@@ -909,7 +912,7 @@ class GR00TN17(PreTrainedModel):
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
for key in ("revision", "cache_dir", "local_files_only", "token"):
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
|
||||
|
||||
@@ -93,12 +93,6 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
# GR00TN17 defines no compute_dtype attribute, so only record the
|
||||
# bf16 preference when it is enabled instead of reading a default back.
|
||||
if self.config.use_bf16:
|
||||
model.compute_dtype = "bfloat16"
|
||||
model.config.compute_dtype = "bfloat16"
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -23,9 +23,10 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.v2.functional as tv_functional
|
||||
from einops import rearrange
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -58,6 +59,7 @@ from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
|
||||
from .configuration_groot import (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
@@ -448,60 +450,40 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||
|
||||
|
||||
def _legacy_groot_processor_overrides(
|
||||
config: GrootConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
|
||||
preprocessor_overrides: dict[str, Any] | None = None,
|
||||
postprocessor_overrides: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Patch older serialized Groot processors with fields current processors expect."""
|
||||
|
||||
preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
|
||||
|
||||
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
|
||||
pack_input_overrides["normalize_min_max"] = True
|
||||
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
|
||||
|
||||
try:
|
||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
|
||||
action_unpack_overrides["normalize_min_max"] = True
|
||||
action_unpack_overrides["env_action_dim"] = env_action_dim
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
|
||||
|
||||
return preprocessor_overrides, postprocessor_overrides
|
||||
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
|
||||
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
|
||||
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
|
||||
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
|
||||
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
|
||||
|
||||
|
||||
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||
"""Check whether a serialized processor pipeline contains a registry step.
|
||||
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
|
||||
|
||||
Resolves the processor config from a local directory or, for Hub repo ids,
|
||||
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
|
||||
False when the config cannot be resolved; loading then proceeds with the
|
||||
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
|
||||
without them if they do not match the serialized pipeline.
|
||||
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
|
||||
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
|
||||
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
|
||||
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
|
||||
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
|
||||
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
|
||||
(e.g. a typo) is left in place and still raises.
|
||||
"""
|
||||
path = Path(pretrained_path).expanduser()
|
||||
if path.is_dir():
|
||||
config = _read_json(path / config_filename)
|
||||
elif path.exists():
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
|
||||
|
||||
if not overrides:
|
||||
return overrides
|
||||
|
||||
filtered: dict[str, Any] = {}
|
||||
for key, value in overrides.items():
|
||||
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
|
||||
logging.debug(
|
||||
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
|
||||
"no matching step (see GrootConfig.normalization_mapping).",
|
||||
key,
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
config = _read_json(Path(config_path))
|
||||
steps = config.get("steps", [])
|
||||
if not isinstance(steps, list):
|
||||
return False
|
||||
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
|
||||
continue
|
||||
filtered[key] = value
|
||||
return filtered
|
||||
|
||||
|
||||
def _apply_groot_step_overrides(
|
||||
@@ -517,7 +499,8 @@ def _apply_groot_step_overrides(
|
||||
steps by registry name only — prefer registry names so overrides keep
|
||||
working after the checkpoint is converted and reloaded from a serialized
|
||||
pipeline). Keys or fields that match nothing raise instead of being dropped
|
||||
silently.
|
||||
silently (standard normalization keys GR00T has no step for are removed
|
||||
beforehand by ``_drop_groot_absent_standard_overrides``).
|
||||
"""
|
||||
|
||||
if not overrides:
|
||||
@@ -573,7 +556,13 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Load Groot processors while preserving compatibility with older serialized configs."""
|
||||
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
|
||||
|
||||
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
|
||||
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
|
||||
# paths raise. This must happen before either branch below.
|
||||
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
|
||||
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
|
||||
|
||||
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
||||
processor_cfg = copy(config)
|
||||
@@ -589,49 +578,13 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
if _pretrained_processor_config_has_step(
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
postprocessor_config_filename,
|
||||
"groot_n1_7_action_decode_v1",
|
||||
):
|
||||
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
|
||||
# action decoder. Adding the legacy action-unpack override would target
|
||||
# a step that is not present and break loading.
|
||||
applied_legacy_overrides = False
|
||||
preprocessor_overrides = caller_preprocessor_overrides
|
||||
postprocessor_overrides = caller_postprocessor_overrides
|
||||
else:
|
||||
applied_legacy_overrides = True
|
||||
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
||||
config=config,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
)
|
||||
try:
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
except KeyError:
|
||||
if not applied_legacy_overrides:
|
||||
raise
|
||||
# The legacy overrides target steps that are absent from the serialized
|
||||
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
|
||||
# config could not be inspected before loading); retry with the caller
|
||||
# overrides only.
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=caller_preprocessor_overrides,
|
||||
postprocessor_overrides=caller_postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
@@ -794,6 +747,10 @@ def make_groot_pre_post_processors(
|
||||
use_albumentations=checkpoint_assets.use_albumentations
|
||||
if checkpoint_assets is not None
|
||||
else False,
|
||||
# Run the image resize/normalize/patchify on the training device when
|
||||
# possible instead of the single CPU main-loop thread (the dominant
|
||||
# cost folded into dataloading_s).
|
||||
device=config.device,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
@@ -1032,6 +989,61 @@ def _transform_n1_7_image_for_vlm(
|
||||
return image
|
||||
|
||||
|
||||
def _transform_n1_7_image_for_vlm_torch(
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
image_crop_size: list[int] | None,
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
) -> torch.Tensor:
|
||||
"""Torch/torchvision port of the non-albumentations branch of
|
||||
:func:`_transform_n1_7_image_for_vlm`.
|
||||
|
||||
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
||||
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
||||
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
|
||||
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
|
||||
cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper.
|
||||
"""
|
||||
if image_target_size is None:
|
||||
return image
|
||||
|
||||
target_h, target_w = image_target_size
|
||||
_, height, width = image.shape
|
||||
|
||||
square_edge = max(height, width)
|
||||
if height != width:
|
||||
left = (square_edge - width) // 2
|
||||
top = (square_edge - height) // 2
|
||||
image = tv_functional.pad(
|
||||
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
||||
)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
image = tv_functional.resize(
|
||||
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
|
||||
)
|
||||
|
||||
if crop_fraction is None and image_crop_size is not None:
|
||||
crop_fraction = image_crop_size[0] / float(target_h)
|
||||
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
||||
# Match the PIL helper's center crop exactly: round() the crop size but
|
||||
# floor() the offset (torchvision.center_crop rounds the offset, which
|
||||
# shifts the region by 1px when (edge - crop) is odd).
|
||||
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
|
||||
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
|
||||
top = max(0, (image.shape[-2] - crop_h) // 2)
|
||||
left = max(0, (image.shape[-1] - crop_w) // 2)
|
||||
image = image[..., top : top + crop_h, left : left + crop_w]
|
||||
|
||||
if tuple(image.shape[-2:]) != (target_h, target_w):
|
||||
image = tv_functional.resize(
|
||||
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
|
||||
class GrootN17PackInputsStep(ProcessorStep):
|
||||
@@ -1058,9 +1070,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
video_modality_keys: list[str] | None = None
|
||||
raw_stats: dict[str, Any] | None = None
|
||||
modality_config: dict[str, Any] | None = None
|
||||
# Unused: kept so serialized configs that include it still load. The raw
|
||||
# state cache is per instance (_last_raw_state), never process-global.
|
||||
state_cache_key: str = ""
|
||||
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
||||
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@@ -1333,6 +1342,12 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
|
||||
an image item in the same chat message so the resulting image tokens match
|
||||
the temporal VLM packing used by Isaac-GR00T.
|
||||
|
||||
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
|
||||
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
|
||||
CUDA device, the resize/rescale/normalize/patchify run there instead of on the
|
||||
single CPU main-loop thread. This keeps the output bit-identical on CPU and
|
||||
moves the dominant preprocessing cost off the critical path on GPU.
|
||||
"""
|
||||
|
||||
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
@@ -1341,6 +1356,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
shortest_image_edge: int | None = None
|
||||
crop_fraction: float | None = None
|
||||
use_albumentations: bool = False
|
||||
device: str | None = None
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
@@ -1349,6 +1365,70 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
self._proc = _build_n1_7_processor(self.model_name)
|
||||
return self._proc
|
||||
|
||||
def _target_device(self) -> torch.device | None:
|
||||
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
|
||||
if self.device is None or self.use_albumentations:
|
||||
return None
|
||||
try:
|
||||
return get_safe_torch_device(self.device)
|
||||
except (AssertionError, RuntimeError):
|
||||
# A device serialized at train time (e.g. "cuda") may be unavailable
|
||||
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
|
||||
# this step is not in the standard device-override set. Fall back to
|
||||
# the CPU path, which is bit-identical, instead of crashing.
|
||||
return None
|
||||
|
||||
def _build_sample_images(
|
||||
self, video: Any, batch_size: int, target_device: torch.device | None
|
||||
) -> list[list[Any]]:
|
||||
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
|
||||
|
||||
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
|
||||
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
|
||||
``target_device`` when set) for the torchvision-backed Qwen processor.
|
||||
"""
|
||||
if self.use_albumentations:
|
||||
video_np = np.asarray(video)
|
||||
return [
|
||||
[
|
||||
_transform_n1_7_image_for_vlm(
|
||||
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
use_albumentations=True,
|
||||
)
|
||||
for timestep in range(video_np.shape[1])
|
||||
for view_idx in range(video_np.shape[2])
|
||||
]
|
||||
for batch_idx in range(batch_size)
|
||||
]
|
||||
|
||||
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
|
||||
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
|
||||
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
|
||||
if target_device is not None and video_t.device != target_device:
|
||||
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
|
||||
|
||||
frames_per_sample: list[list[Any]] = []
|
||||
for batch_idx in range(batch_size):
|
||||
sample = video_t[batch_idx] # (T, V, C, H, W)
|
||||
frames_per_sample.append(
|
||||
[
|
||||
_transform_n1_7_image_for_vlm_torch(
|
||||
sample[timestep, view_idx],
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
)
|
||||
for timestep in range(sample.shape[0])
|
||||
for view_idx in range(sample.shape[1])
|
||||
]
|
||||
)
|
||||
return frames_per_sample
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
@@ -1356,33 +1436,25 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
if video is None:
|
||||
return transition
|
||||
|
||||
batch_size = int(video.shape[0])
|
||||
languages = _prepare_n1_7_language_batch(
|
||||
comp.get("language"),
|
||||
video.shape[0],
|
||||
batch_size,
|
||||
formalize_language=False,
|
||||
)
|
||||
|
||||
target_device = self._target_device()
|
||||
sample_images = self._build_sample_images(video, batch_size, target_device)
|
||||
|
||||
texts: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for batch_idx in range(video.shape[0]):
|
||||
sample = video[batch_idx] # (T, V, H, W, C)
|
||||
sample_images = [
|
||||
_transform_n1_7_image_for_vlm(
|
||||
Image.fromarray(sample[timestep, view_idx]),
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
use_albumentations=self.use_albumentations,
|
||||
)
|
||||
for timestep in range(sample.shape[0])
|
||||
for view_idx in range(sample.shape[1])
|
||||
]
|
||||
images: list[Any] = []
|
||||
for batch_idx in range(batch_size):
|
||||
frames = sample_images[batch_idx]
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*[{"type": "image", "image": image} for image in sample_images],
|
||||
*[{"type": "image", "image": image} for image in frames],
|
||||
{"type": "text", "text": languages[batch_idx]},
|
||||
],
|
||||
}
|
||||
@@ -1394,9 +1466,17 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
)
|
||||
images.extend(sample_images)
|
||||
images.extend(frames)
|
||||
|
||||
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
proc_kwargs: dict[str, Any] = {
|
||||
"text": texts,
|
||||
"images": images,
|
||||
"return_tensors": "pt",
|
||||
"padding": True,
|
||||
}
|
||||
if target_device is not None:
|
||||
proc_kwargs["device"] = str(target_device)
|
||||
encoded = self.proc(**proc_kwargs)
|
||||
for key, value in encoded.items():
|
||||
comp[key] = value
|
||||
obs.pop("video", None)
|
||||
@@ -1415,6 +1495,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
"shortest_image_edge": self.shortest_image_edge,
|
||||
"crop_fraction": self.crop_fraction,
|
||||
"use_albumentations": self.use_albumentations,
|
||||
"device": self.device,
|
||||
}
|
||||
|
||||
|
||||
@@ -1565,8 +1646,6 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
modality_config: dict[str, Any] | None = None
|
||||
use_percentiles: bool = False
|
||||
use_relative_action: bool = False
|
||||
# Unused: kept so serialized configs that include it still load.
|
||||
state_cache_key: str = ""
|
||||
action_decode_transform: str | None = None
|
||||
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
||||
|
||||
@@ -1694,10 +1773,10 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
|
||||
# action chunks to the last timestep, so old serialized v1 pipelines must not
|
||||
# silently load into it (v1 is stubbed below with the removal guidance).
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
|
||||
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
env_action_dim: int = 0
|
||||
|
||||
Reference in New Issue
Block a user