mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-15 07:19:51 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f81fc79564 | |||
| 9ce6633518 |
@@ -321,6 +321,9 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
|
||||
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
|
||||
# instead of being silently treated as N1.7 (see infer_groot_model_version).
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
@@ -362,7 +365,11 @@ class GrootConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# Groot-specific model parameters
|
||||
# Deprecated and unused: image sizing is handled by the backbone's image processor.
|
||||
# Kept only so config.json files saved with earlier versions still parse.
|
||||
image_size: tuple[int, int] = (256, 256)
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
@@ -378,6 +385,11 @@ class GrootConfig(PreTrainedConfig):
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
|
||||
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
|
||||
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
|
||||
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
|
||||
@@ -416,13 +428,10 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
video_backend: str = "decord"
|
||||
balance_dataset_weights: bool = True
|
||||
balance_trajectory_weights: bool = True
|
||||
@@ -436,6 +445,9 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
|
||||
# serialized into every groot checkpoint config.json, so a value here means a legacy
|
||||
# N1.5 checkpoint or config is being loaded.
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
@@ -570,11 +582,22 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
"""Return indices for delta actions.
|
||||
|
||||
The model action horizon is read from the checkpoint's processor_config.json
|
||||
when available; the result is cached (keyed on the inputs that determine it) so
|
||||
repeated access during dataset/training setup does not re-read from disk.
|
||||
"""
|
||||
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
|
||||
cached = getattr(self, "_action_delta_indices_cache", None)
|
||||
if cached is not None and cached[0] == cache_key:
|
||||
return cached[1]
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
indices = list(range(min(self.chunk_size, model_action_horizon)))
|
||||
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
|
||||
return indices
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -71,7 +71,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 16,
|
||||
"select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"load_bf16": False,
|
||||
@@ -822,6 +822,8 @@ def get_backbone_cls(config: GR00TN17Config):
|
||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||
return Qwen3Backbone
|
||||
if config.backbone_model_type == "qwen":
|
||||
# Local backbone checkpoints (e.g. hub-cache snapshot paths) contain neither hub
|
||||
# marker, so trust the explicit backbone type but surface what is being assumed.
|
||||
logger.warning(
|
||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||
"backbone because backbone_model_type='qwen'.",
|
||||
@@ -912,6 +914,10 @@ class GR00TN17(PreTrainedModel):
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
# Only repo-agnostic hub kwargs are forwarded to the backbone loading kwargs:
|
||||
# ``revision`` pins the GR00T checkpoint repo (see snapshot_download below) and would
|
||||
# be invalid for the unrelated backbone repo (``config.model_name``). Pin the backbone
|
||||
# itself by passing ``revision`` inside ``transformers_loading_kwargs``.
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
|
||||
@@ -93,6 +93,12 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
# GR00TN17 defines no compute_dtype attribute, so only record the
|
||||
# bf16 preference when it is enabled instead of reading a default back.
|
||||
if self.config.use_bf16:
|
||||
model.compute_dtype = "bfloat16"
|
||||
model.config.compute_dtype = "bfloat16"
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -23,10 +23,9 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.v2.functional as tv_functional
|
||||
from einops import rearrange
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -59,7 +58,6 @@ from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
|
||||
from .configuration_groot import (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
@@ -450,40 +448,60 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||
|
||||
|
||||
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
|
||||
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
|
||||
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
|
||||
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
|
||||
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
|
||||
def _legacy_groot_processor_overrides(
|
||||
config: GrootConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
|
||||
preprocessor_overrides: dict[str, Any] | None = None,
|
||||
postprocessor_overrides: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Patch older serialized Groot processors with fields current processors expect."""
|
||||
|
||||
preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
|
||||
|
||||
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
|
||||
pack_input_overrides["normalize_min_max"] = True
|
||||
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
|
||||
|
||||
try:
|
||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||
except Exception:
|
||||
env_action_dim = 0
|
||||
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
|
||||
action_unpack_overrides["normalize_min_max"] = True
|
||||
action_unpack_overrides["env_action_dim"] = env_action_dim
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
|
||||
|
||||
return preprocessor_overrides, postprocessor_overrides
|
||||
|
||||
|
||||
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
|
||||
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||
"""Check whether a serialized processor pipeline contains a registry step.
|
||||
|
||||
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
|
||||
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
|
||||
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
|
||||
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
|
||||
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
|
||||
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
|
||||
(e.g. a typo) is left in place and still raises.
|
||||
Resolves the processor config from a local directory or, for Hub repo ids,
|
||||
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
|
||||
False when the config cannot be resolved; loading then proceeds with the
|
||||
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
|
||||
without them if they do not match the serialized pipeline.
|
||||
"""
|
||||
|
||||
if not overrides:
|
||||
return overrides
|
||||
|
||||
filtered: dict[str, Any] = {}
|
||||
for key, value in overrides.items():
|
||||
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
|
||||
logging.debug(
|
||||
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
|
||||
"no matching step (see GrootConfig.normalization_mapping).",
|
||||
key,
|
||||
path = Path(pretrained_path).expanduser()
|
||||
if path.is_dir():
|
||||
config = _read_json(path / config_filename)
|
||||
elif path.exists():
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
|
||||
)
|
||||
continue
|
||||
filtered[key] = value
|
||||
return filtered
|
||||
except Exception:
|
||||
return False
|
||||
config = _read_json(Path(config_path))
|
||||
steps = config.get("steps", [])
|
||||
if not isinstance(steps, list):
|
||||
return False
|
||||
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
|
||||
|
||||
|
||||
def _apply_groot_step_overrides(
|
||||
@@ -499,8 +517,7 @@ def _apply_groot_step_overrides(
|
||||
steps by registry name only — prefer registry names so overrides keep
|
||||
working after the checkpoint is converted and reloaded from a serialized
|
||||
pipeline). Keys or fields that match nothing raise instead of being dropped
|
||||
silently (standard normalization keys GR00T has no step for are removed
|
||||
beforehand by ``_drop_groot_absent_standard_overrides``).
|
||||
silently.
|
||||
"""
|
||||
|
||||
if not overrides:
|
||||
@@ -556,13 +573,7 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
|
||||
|
||||
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
|
||||
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
|
||||
# paths raise. This must happen before either branch below.
|
||||
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
|
||||
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
|
||||
"""Load Groot processors while preserving compatibility with older serialized configs."""
|
||||
|
||||
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
||||
processor_cfg = copy(config)
|
||||
@@ -578,13 +589,49 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||
if _pretrained_processor_config_has_step(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
postprocessor_config_filename,
|
||||
"groot_n1_7_action_decode_v1",
|
||||
):
|
||||
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
|
||||
# action decoder. Adding the legacy action-unpack override would target
|
||||
# a step that is not present and break loading.
|
||||
applied_legacy_overrides = False
|
||||
preprocessor_overrides = caller_preprocessor_overrides
|
||||
postprocessor_overrides = caller_postprocessor_overrides
|
||||
else:
|
||||
applied_legacy_overrides = True
|
||||
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
||||
config=config,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
)
|
||||
try:
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
postprocessor_overrides=postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
except KeyError:
|
||||
if not applied_legacy_overrides:
|
||||
raise
|
||||
# The legacy overrides target steps that are absent from the serialized
|
||||
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
|
||||
# config could not be inspected before loading); retry with the caller
|
||||
# overrides only.
|
||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||
pretrained_path,
|
||||
preprocessor_overrides=caller_preprocessor_overrides,
|
||||
postprocessor_overrides=caller_postprocessor_overrides,
|
||||
preprocessor_config_filename=preprocessor_config_filename,
|
||||
postprocessor_config_filename=postprocessor_config_filename,
|
||||
)
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
@@ -747,10 +794,6 @@ def make_groot_pre_post_processors(
|
||||
use_albumentations=checkpoint_assets.use_albumentations
|
||||
if checkpoint_assets is not None
|
||||
else False,
|
||||
# Run the image resize/normalize/patchify on the training device when
|
||||
# possible instead of the single CPU main-loop thread (the dominant
|
||||
# cost folded into dataloading_s).
|
||||
device=config.device,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
@@ -989,61 +1032,6 @@ def _transform_n1_7_image_for_vlm(
|
||||
return image
|
||||
|
||||
|
||||
def _transform_n1_7_image_for_vlm_torch(
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
image_crop_size: list[int] | None,
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
) -> torch.Tensor:
|
||||
"""Torch/torchvision port of the non-albumentations branch of
|
||||
:func:`_transform_n1_7_image_for_vlm`.
|
||||
|
||||
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
||||
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
||||
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
|
||||
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
|
||||
cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper.
|
||||
"""
|
||||
if image_target_size is None:
|
||||
return image
|
||||
|
||||
target_h, target_w = image_target_size
|
||||
_, height, width = image.shape
|
||||
|
||||
square_edge = max(height, width)
|
||||
if height != width:
|
||||
left = (square_edge - width) // 2
|
||||
top = (square_edge - height) // 2
|
||||
image = tv_functional.pad(
|
||||
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
||||
)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
image = tv_functional.resize(
|
||||
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
|
||||
)
|
||||
|
||||
if crop_fraction is None and image_crop_size is not None:
|
||||
crop_fraction = image_crop_size[0] / float(target_h)
|
||||
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
||||
# Match the PIL helper's center crop exactly: round() the crop size but
|
||||
# floor() the offset (torchvision.center_crop rounds the offset, which
|
||||
# shifts the region by 1px when (edge - crop) is odd).
|
||||
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
|
||||
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
|
||||
top = max(0, (image.shape[-2] - crop_h) // 2)
|
||||
left = max(0, (image.shape[-1] - crop_w) // 2)
|
||||
image = image[..., top : top + crop_h, left : left + crop_w]
|
||||
|
||||
if tuple(image.shape[-2:]) != (target_h, target_w):
|
||||
image = tv_functional.resize(
|
||||
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
|
||||
class GrootN17PackInputsStep(ProcessorStep):
|
||||
@@ -1070,6 +1058,9 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
video_modality_keys: list[str] | None = None
|
||||
raw_stats: dict[str, Any] | None = None
|
||||
modality_config: dict[str, Any] | None = None
|
||||
# Unused: kept so serialized configs that include it still load. The raw
|
||||
# state cache is per instance (_last_raw_state), never process-global.
|
||||
state_cache_key: str = ""
|
||||
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
||||
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@@ -1342,12 +1333,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
|
||||
an image item in the same chat message so the resulting image tokens match
|
||||
the temporal VLM packing used by Isaac-GR00T.
|
||||
|
||||
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
|
||||
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
|
||||
CUDA device, the resize/rescale/normalize/patchify run there instead of on the
|
||||
single CPU main-loop thread. This keeps the output bit-identical on CPU and
|
||||
moves the dominant preprocessing cost off the critical path on GPU.
|
||||
"""
|
||||
|
||||
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
@@ -1356,7 +1341,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
shortest_image_edge: int | None = None
|
||||
crop_fraction: float | None = None
|
||||
use_albumentations: bool = False
|
||||
device: str | None = None
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@property
|
||||
@@ -1365,70 +1349,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
self._proc = _build_n1_7_processor(self.model_name)
|
||||
return self._proc
|
||||
|
||||
def _target_device(self) -> torch.device | None:
|
||||
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
|
||||
if self.device is None or self.use_albumentations:
|
||||
return None
|
||||
try:
|
||||
return get_safe_torch_device(self.device)
|
||||
except (AssertionError, RuntimeError):
|
||||
# A device serialized at train time (e.g. "cuda") may be unavailable
|
||||
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
|
||||
# this step is not in the standard device-override set. Fall back to
|
||||
# the CPU path, which is bit-identical, instead of crashing.
|
||||
return None
|
||||
|
||||
def _build_sample_images(
|
||||
self, video: Any, batch_size: int, target_device: torch.device | None
|
||||
) -> list[list[Any]]:
|
||||
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
|
||||
|
||||
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
|
||||
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
|
||||
``target_device`` when set) for the torchvision-backed Qwen processor.
|
||||
"""
|
||||
if self.use_albumentations:
|
||||
video_np = np.asarray(video)
|
||||
return [
|
||||
[
|
||||
_transform_n1_7_image_for_vlm(
|
||||
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
use_albumentations=True,
|
||||
)
|
||||
for timestep in range(video_np.shape[1])
|
||||
for view_idx in range(video_np.shape[2])
|
||||
]
|
||||
for batch_idx in range(batch_size)
|
||||
]
|
||||
|
||||
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
|
||||
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
|
||||
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
|
||||
if target_device is not None and video_t.device != target_device:
|
||||
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
|
||||
|
||||
frames_per_sample: list[list[Any]] = []
|
||||
for batch_idx in range(batch_size):
|
||||
sample = video_t[batch_idx] # (T, V, C, H, W)
|
||||
frames_per_sample.append(
|
||||
[
|
||||
_transform_n1_7_image_for_vlm_torch(
|
||||
sample[timestep, view_idx],
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
)
|
||||
for timestep in range(sample.shape[0])
|
||||
for view_idx in range(sample.shape[1])
|
||||
]
|
||||
)
|
||||
return frames_per_sample
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
@@ -1436,25 +1356,33 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
if video is None:
|
||||
return transition
|
||||
|
||||
batch_size = int(video.shape[0])
|
||||
languages = _prepare_n1_7_language_batch(
|
||||
comp.get("language"),
|
||||
batch_size,
|
||||
video.shape[0],
|
||||
formalize_language=False,
|
||||
)
|
||||
|
||||
target_device = self._target_device()
|
||||
sample_images = self._build_sample_images(video, batch_size, target_device)
|
||||
|
||||
texts: list[str] = []
|
||||
images: list[Any] = []
|
||||
for batch_idx in range(batch_size):
|
||||
frames = sample_images[batch_idx]
|
||||
images: list[Image.Image] = []
|
||||
for batch_idx in range(video.shape[0]):
|
||||
sample = video[batch_idx] # (T, V, H, W, C)
|
||||
sample_images = [
|
||||
_transform_n1_7_image_for_vlm(
|
||||
Image.fromarray(sample[timestep, view_idx]),
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
use_albumentations=self.use_albumentations,
|
||||
)
|
||||
for timestep in range(sample.shape[0])
|
||||
for view_idx in range(sample.shape[1])
|
||||
]
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*[{"type": "image", "image": image} for image in frames],
|
||||
*[{"type": "image", "image": image} for image in sample_images],
|
||||
{"type": "text", "text": languages[batch_idx]},
|
||||
],
|
||||
}
|
||||
@@ -1466,17 +1394,9 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
)
|
||||
images.extend(frames)
|
||||
images.extend(sample_images)
|
||||
|
||||
proc_kwargs: dict[str, Any] = {
|
||||
"text": texts,
|
||||
"images": images,
|
||||
"return_tensors": "pt",
|
||||
"padding": True,
|
||||
}
|
||||
if target_device is not None:
|
||||
proc_kwargs["device"] = str(target_device)
|
||||
encoded = self.proc(**proc_kwargs)
|
||||
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
for key, value in encoded.items():
|
||||
comp[key] = value
|
||||
obs.pop("video", None)
|
||||
@@ -1495,7 +1415,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
"shortest_image_edge": self.shortest_image_edge,
|
||||
"crop_fraction": self.crop_fraction,
|
||||
"use_albumentations": self.use_albumentations,
|
||||
"device": self.device,
|
||||
}
|
||||
|
||||
|
||||
@@ -1646,6 +1565,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
modality_config: dict[str, Any] | None = None
|
||||
use_percentiles: bool = False
|
||||
use_relative_action: bool = False
|
||||
# Unused: kept so serialized configs that include it still load.
|
||||
state_cache_key: str = ""
|
||||
action_decode_transform: str | None = None
|
||||
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
||||
|
||||
@@ -1773,10 +1694,10 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
|
||||
# action chunks to the last timestep, so old serialized v1 pipelines must not
|
||||
# silently load into it (v1 is stubbed below with the removal guidance).
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
|
||||
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
env_action_dim: int = 0
|
||||
|
||||
@@ -207,6 +207,11 @@ def test_lerobot_groot_forward_pass():
|
||||
with torch.no_grad():
|
||||
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
|
||||
|
||||
assert isinstance(lerobot_loss, torch.Tensor)
|
||||
assert torch.isfinite(lerobot_loss).all()
|
||||
assert "loss" in lerobot_metrics
|
||||
assert np.isfinite(lerobot_metrics["loss"])
|
||||
|
||||
print("\nForward pass successful.")
|
||||
print(f" - Loss: {lerobot_loss.item():.6f}")
|
||||
print(f" - Metrics: {lerobot_metrics}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,31 +14,36 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||
"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||
|
||||
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
|
||||
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
|
||||
normalized flow-matching prediction before any action decoding) as NVIDIA's original
|
||||
``gr00t`` package, given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed. The comparison is parametrized over every embodiment tag present
|
||||
in the checkpoint.
|
||||
Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
|
||||
once in the original ``gr00t`` env by the companion script
|
||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):
|
||||
|
||||
To keep the comparison fair, the original outputs + the exact collated inputs are
|
||||
produced once per embodiment in the original ``gr00t`` env via the companion script
|
||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
|
||||
to per-tag ``.npz`` files.
|
||||
This test discovers those artifacts, replays the identical inputs through the LeRobot
|
||||
model, and compares.
|
||||
1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
|
||||
action head + Qwen3-VL backbone must produce the SAME raw model output
|
||||
(``action_pred``, the normalized flow-matching prediction before any action
|
||||
decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
|
||||
pre-processed inputs and the flow-matching seed recorded in the artifact.
|
||||
2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
|
||||
template / tokenizer / image packing + state normalization, no mocks) must produce
|
||||
the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
|
||||
as the original package's processor, given the identical raw observations
|
||||
(images, state, language) recorded in the artifact. Artifacts written by older
|
||||
versions of the dump script carry no raw observations; this case then SKIPS with
|
||||
a regeneration hint.
|
||||
|
||||
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
|
||||
present, or when no artifact has been generated. By default it looks for artifacts in
|
||||
These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
|
||||
present, or when no artifact has been generated. By default they look for artifacts in
|
||||
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
||||
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
||||
for the full run procedure.
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -50,7 +55,9 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||
|
||||
# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
|
||||
SEED = 42
|
||||
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
||||
@@ -60,6 +67,11 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
|
||||
_ARTIFACT_PREFIX = "original_n1_7_"
|
||||
_ARTIFACT_SUFFIX = ".npz"
|
||||
|
||||
# Collated keys compared by the preprocessor parity case: integer/id tensors must
|
||||
# match exactly; float tensors within ATOL/RTOL.
|
||||
_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
|
||||
_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
|
||||
|
||||
|
||||
def _artifact_dir() -> Path:
|
||||
"""Directory holding the per-embodiment .npz artifacts.
|
||||
@@ -109,9 +121,20 @@ def _resolve_checkpoint() -> str:
|
||||
return str(ckpt)
|
||||
|
||||
|
||||
def _load_artifact(path: Path):
|
||||
def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
|
||||
"""Return (original action_pred, collated model inputs, flow-matching seed)."""
|
||||
data = np.load(path, allow_pickle=True)
|
||||
original_action = torch.from_numpy(data["action_pred"]).float()
|
||||
if "seed" in data.files:
|
||||
seed = int(data["seed"])
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Artifact '{path.name}' does not record the producer seed (it predates the current "
|
||||
f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
|
||||
"regenerate the artifact with the current dump script.",
|
||||
stacklevel=2,
|
||||
)
|
||||
seed = SEED
|
||||
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
||||
inputs = {}
|
||||
for key in data.files:
|
||||
@@ -124,7 +147,45 @@ def _load_artifact(path: Path):
|
||||
if "int" in declared or "long" in declared:
|
||||
t = t.long()
|
||||
inputs[name] = t
|
||||
return original_action, inputs
|
||||
return original_action, inputs, seed
|
||||
|
||||
|
||||
def _load_raw_observation(path: Path) -> dict[str, Any] | None:
|
||||
"""Return the raw observation recorded in the artifact, or None for old artifacts.
|
||||
|
||||
Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
|
||||
exact raw observation the producer fed to the original processor: per-camera uint8
|
||||
frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
|
||||
(``raw::state.<key>``, (B, T, dim)) and the language instruction
|
||||
(``raw::language``, one string per batch element). ``raw_video_keys`` /
|
||||
``raw_state_keys`` record the checkpoint modality-key order.
|
||||
"""
|
||||
data = np.load(path, allow_pickle=True)
|
||||
markers = ("raw_video_keys", "raw_state_keys", "raw::language")
|
||||
if any(marker not in data.files for marker in markers):
|
||||
return None
|
||||
video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
|
||||
state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
|
||||
return {
|
||||
"video": {k: data[f"raw::video.{k}"] for k in video_keys},
|
||||
"state": {k: data[f"raw::state.{k}"] for k in state_keys},
|
||||
"language": [str(t) for t in data["raw::language"].tolist()],
|
||||
}
|
||||
|
||||
|
||||
def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert the producer's raw observation into a LeRobot policy batch."""
|
||||
batch: dict[str, Any] = {}
|
||||
for key, frames in raw["video"].items():
|
||||
# (B, T, H, W, C) uint8 -> (B, T, C, H, W); the pack step converts back losslessly.
|
||||
batch[f"{OBS_IMAGES}.{key}"] = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).contiguous()
|
||||
# observation.state is the per-key state vectors (latest frame) concatenated in
|
||||
# checkpoint modality-key order -- the layout the LeRobot pack step and the
|
||||
# flattened checkpoint statistics expect.
|
||||
state_parts = [torch.from_numpy(np.asarray(arr)[:, -1, :]).float() for arr in raw["state"].values()]
|
||||
batch[OBS_STATE] = torch.cat(state_parts, dim=-1)
|
||||
batch["task"] = list(raw["language"])
|
||||
return batch
|
||||
|
||||
|
||||
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||
@@ -139,6 +200,36 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||
return nested.get("inputs", nested)
|
||||
|
||||
|
||||
def _assert_collated_parity(
|
||||
embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
|
||||
) -> None:
|
||||
"""Compare one collated tensor produced by LeRobot against the original's."""
|
||||
assert isinstance(lerobot_value, torch.Tensor), (
|
||||
f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
|
||||
f"{type(lerobot_value).__name__}, expected a tensor."
|
||||
)
|
||||
lerobot_t = lerobot_value.detach().cpu()
|
||||
original_t = original_value.detach().cpu()
|
||||
assert lerobot_t.shape == original_t.shape, (
|
||||
f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs "
|
||||
f"original={tuple(original_t.shape)}."
|
||||
)
|
||||
if exact:
|
||||
mismatched = int((lerobot_t.long() != original_t.long()).sum())
|
||||
assert mismatched == 0, (
|
||||
f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
|
||||
f"{mismatched}/{original_t.numel()} elements mismatch."
|
||||
)
|
||||
else:
|
||||
lerobot_f, original_f = lerobot_t.float(), original_t.float()
|
||||
max_diff = (lerobot_f - original_f).abs().max().item()
|
||||
print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}")
|
||||
assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), (
|
||||
f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
|
||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lerobot_model():
|
||||
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
||||
@@ -165,8 +256,7 @@ def lerobot_model():
|
||||
|
||||
_ARTIFACTS = _discover_artifacts()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_requires_artifacts = pytest.mark.skipif(
|
||||
not _ARTIFACTS,
|
||||
reason=(
|
||||
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
||||
@@ -174,24 +264,30 @@ _ARTIFACTS = _discover_artifacts()
|
||||
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@_requires_artifacts
|
||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
||||
original_action, flat_inputs = _load_artifact(artifact)
|
||||
original_action, flat_inputs, seed = _load_artifact(artifact)
|
||||
model_inputs = _unflatten(flat_inputs)
|
||||
|
||||
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
||||
torch.manual_seed(SEED)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(SEED)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
with torch.inference_mode():
|
||||
out = lerobot_model.get_action(model_inputs)
|
||||
lerobot_action = out["action_pred"].float().cpu()
|
||||
|
||||
t = min(original_action.shape[1], lerobot_action.shape[1])
|
||||
d = min(original_action.shape[2], lerobot_action.shape[2])
|
||||
original_action = original_action[:, :t, :d]
|
||||
lerobot_action = lerobot_action[:, :t, :d]
|
||||
assert lerobot_action.shape == original_action.shape, (
|
||||
f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
|
||||
f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
|
||||
"The same checkpoint and inputs must produce identical shapes; this indicates an "
|
||||
"action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
|
||||
"utils/dump_original_n1_7.py)."
|
||||
)
|
||||
|
||||
diff = torch.abs(lerobot_action - original_action)
|
||||
max_diff = diff.max().item()
|
||||
@@ -205,3 +301,56 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
||||
)
|
||||
|
||||
|
||||
@_requires_artifacts
|
||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||
def test_groot_preprocessor_parity(embodiment_tag, artifact):
|
||||
"""LeRobot's real preprocessor vs the original's collated tensors, from identical raw obs.
|
||||
|
||||
Runs LeRobot's full preprocessor pipeline -- including the real Qwen3-VL chat
|
||||
template, tokenizer and image packing plus the checkpoint-driven state
|
||||
normalization (no mocks) -- on the raw observations recorded in the artifact, and
|
||||
compares every collated model input against the ones the original ``gr00t``
|
||||
processor produced from the same raw observations.
|
||||
"""
|
||||
raw = _load_raw_observation(artifact)
|
||||
if raw is None:
|
||||
pytest.skip(
|
||||
f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
|
||||
"not record raw observations; regenerate it with the current dump script to run the "
|
||||
"preprocessor parity case."
|
||||
)
|
||||
_, flat_inputs, _ = _load_artifact(artifact)
|
||||
original_inputs = _unflatten(flat_inputs)
|
||||
|
||||
ckpt = _resolve_checkpoint()
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
# CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
|
||||
config = GrootConfig(base_model_path=ckpt, embodiment_tag=embodiment_tag, device="cpu")
|
||||
preprocessor, _ = make_groot_pre_post_processors(config)
|
||||
|
||||
processed = preprocessor(_raw_observation_to_lerobot_batch(raw))
|
||||
|
||||
compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS)
|
||||
missing_original = [k for k in compared_keys if k not in original_inputs]
|
||||
missing_lerobot = [k for k in compared_keys if k not in processed]
|
||||
assert not missing_original, (
|
||||
f"[{embodiment_tag}] artifact collated inputs miss {missing_original} "
|
||||
f"(available: {sorted(original_inputs)}); regenerate the artifact with the current dump script."
|
||||
)
|
||||
assert not missing_lerobot, (
|
||||
f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys "
|
||||
f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})."
|
||||
)
|
||||
|
||||
for name in compared_keys:
|
||||
_assert_collated_parity(
|
||||
embodiment_tag,
|
||||
name,
|
||||
processed[name],
|
||||
original_inputs[name],
|
||||
exact=name in _COLLATED_EXACT_KEYS,
|
||||
)
|
||||
|
||||
@@ -9,6 +9,9 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
|
||||
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
||||
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
||||
|
||||
* the RAW observation fed to the original processor (per-camera uint8 frames,
|
||||
per-key state vectors, the language instruction), so the LeRobot side can also
|
||||
run its OWN preprocessor on identical raw inputs and compare collated tensors,
|
||||
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
||||
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
||||
* the random seed used right before the flow-matching sampler,
|
||||
@@ -21,8 +24,10 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
|
||||
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
||||
``libero_sim``.
|
||||
|
||||
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
|
||||
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
|
||||
The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
|
||||
twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
|
||||
(model parity), and the raw observation is replayed through LeRobot's own
|
||||
preprocessor pipeline and compared against the collated inputs (preprocessor parity).
|
||||
|
||||
Usage:
|
||||
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
@@ -62,10 +67,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
|
||||
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
||||
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
||||
# present-but-empty so the processor's state transform finds every expected key.
|
||||
state = {
|
||||
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
|
||||
for k, dim in state_spec
|
||||
}
|
||||
state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
|
||||
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
||||
return {"video": video, "state": state, "language": language}
|
||||
|
||||
@@ -77,6 +79,25 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
||||
lang_key = modality_cfg["language"].modality_keys[0]
|
||||
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
||||
|
||||
# Snapshot the RAW observation exactly as fed to the original processor below. The
|
||||
# consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
|
||||
# and compares the resulting collated tensors against the "in::" ones saved further
|
||||
# down. raw_state_keys records the checkpoint modality-key order, which is the
|
||||
# concatenation order of the flat LeRobot ``observation.state`` vector.
|
||||
spec_keys = [key for key, _ in state_spec]
|
||||
state_modality = modality_cfg.get("state")
|
||||
state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
|
||||
state_keys += [key for key in spec_keys if key not in state_keys]
|
||||
raw_language = [
|
||||
str(item[0]) if isinstance(item, (list, tuple)) else str(item)
|
||||
for item in observation["language"][lang_key]
|
||||
]
|
||||
raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
|
||||
raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
|
||||
raw_flat["raw::language"] = np.array(raw_language, dtype=object)
|
||||
raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
|
||||
raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
|
||||
|
||||
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
||||
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
||||
policy.modality_configs = {
|
||||
@@ -136,6 +157,7 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
||||
embodiment_tag=np.array(tag),
|
||||
meta_keys=np.array(list(meta.keys()), dtype=object),
|
||||
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
||||
**raw_flat,
|
||||
**flat,
|
||||
)
|
||||
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
||||
@@ -181,7 +203,12 @@ def main():
|
||||
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
||||
try:
|
||||
dump_one_tag(
|
||||
policy, fair_model, tag, all_modality[tag], state_spec, args,
|
||||
policy,
|
||||
fair_model,
|
||||
tag,
|
||||
all_modality[tag],
|
||||
state_spec,
|
||||
args,
|
||||
out_dir / f"original_n1_7_{tag}.npz",
|
||||
)
|
||||
done.append(tag)
|
||||
|
||||
Reference in New Issue
Block a user