diff --git a/tests/policies/groot/test_groot_vs_original.py b/tests/policies/groot/test_groot_vs_original.py index a46915b45..0cbdb877d 100644 --- a/tests/policies/groot/test_groot_vs_original.py +++ b/tests/policies/groot/test_groot_vs_original.py @@ -14,31 +14,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot. +"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot. -Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action -head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the -normalized flow-matching prediction before any action decoding) as NVIDIA's original -``gr00t`` package, given byte-identical pre-processed inputs and the same -flow-matching seed. The comparison is parametrized over every embodiment tag present -in the checkpoint. +Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced +once in the original ``gr00t`` env by the companion script +``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file): -To keep the comparison fair, the original outputs + the exact collated inputs are -produced once per embodiment in the original ``gr00t`` env via the companion script -``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved -to per-tag ``.npz`` files. -This test discovers those artifacts, replays the identical inputs through the LeRobot -model, and compares. +1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7 + action head + Qwen3-VL backbone must produce the SAME raw model output + (``action_pred``, the normalized flow-matching prediction before any action + decoding) as NVIDIA's original ``gr00t`` package, given byte-identical + pre-processed inputs and the flow-matching seed recorded in the artifact. +2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat + template / tokenizer / image packing + state normalization, no mocks) must produce + the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...) + as the original package's processor, given the identical raw observations + (images, state, language) recorded in the artifact. Artifacts written by older + versions of the dump script carry no raw observations; this case then SKIPS with + a regeneration hint. -This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not -present, or when no artifact has been generated. By default it looks for artifacts in +These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not +present, or when no artifact has been generated. By default they look for artifacts in ``/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the "Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md`` for the full run procedure. """ import os +import warnings from pathlib import Path +from typing import Any import numpy as np import pytest @@ -50,7 +55,9 @@ pytestmark = pytest.mark.skipif( ) from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401 +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402 +# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field. SEED = 42 DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu") ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3")) @@ -60,6 +67,11 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3")) _ARTIFACT_PREFIX = "original_n1_7_" _ARTIFACT_SUFFIX = ".npz" +# Collated keys compared by the preprocessor parity case: integer/id tensors must +# match exactly; float tensors within ATOL/RTOL. +_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id") +_COLLATED_CLOSE_KEYS = ("pixel_values", "state") + def _artifact_dir() -> Path: """Directory holding the per-embodiment .npz artifacts. @@ -109,9 +121,20 @@ def _resolve_checkpoint() -> str: return str(ckpt) -def _load_artifact(path: Path): +def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]: + """Return (original action_pred, collated model inputs, flow-matching seed).""" data = np.load(path, allow_pickle=True) original_action = torch.from_numpy(data["action_pred"]).float() + if "seed" in data.files: + seed = int(data["seed"]) + else: + warnings.warn( + f"Artifact '{path.name}' does not record the producer seed (it predates the current " + f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, " + "regenerate the artifact with the current dump script.", + stacklevel=2, + ) + seed = SEED dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False)) inputs = {} for key in data.files: @@ -124,7 +147,45 @@ def _load_artifact(path: Path): if "int" in declared or "long" in declared: t = t.long() inputs[name] = t - return original_action, inputs + return original_action, inputs, seed + + +def _load_raw_observation(path: Path) -> dict[str, Any] | None: + """Return the raw observation recorded in the artifact, or None for old artifacts. + + Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the + exact raw observation the producer fed to the original processor: per-camera uint8 + frames (``raw::video.``, (B, T, H, W, C)), per-key state vectors + (``raw::state.``, (B, T, dim)) and the language instruction + (``raw::language``, one string per batch element). ``raw_video_keys`` / + ``raw_state_keys`` record the checkpoint modality-key order. + """ + data = np.load(path, allow_pickle=True) + markers = ("raw_video_keys", "raw_state_keys", "raw::language") + if any(marker not in data.files for marker in markers): + return None + video_keys = [str(k) for k in data["raw_video_keys"].tolist()] + state_keys = [str(k) for k in data["raw_state_keys"].tolist()] + return { + "video": {k: data[f"raw::video.{k}"] for k in video_keys}, + "state": {k: data[f"raw::state.{k}"] for k in state_keys}, + "language": [str(t) for t in data["raw::language"].tolist()], + } + + +def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]: + """Convert the producer's raw observation into a LeRobot policy batch.""" + batch: dict[str, Any] = {} + for key, frames in raw["video"].items(): + # (B, T, H, W, C) uint8 -> (B, T, C, H, W); the pack step converts back losslessly. + batch[f"{OBS_IMAGES}.{key}"] = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).contiguous() + # observation.state is the per-key state vectors (latest frame) concatenated in + # checkpoint modality-key order -- the layout the LeRobot pack step and the + # flattened checkpoint statistics expect. + state_parts = [torch.from_numpy(np.asarray(arr)[:, -1, :]).float() for arr in raw["state"].values()] + batch[OBS_STATE] = torch.cat(state_parts, dim=-1) + batch["task"] = list(raw["language"]) + return batch def _unflatten(inputs: dict[str, torch.Tensor]) -> dict: @@ -139,6 +200,36 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict: return nested.get("inputs", nested) +def _assert_collated_parity( + embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool +) -> None: + """Compare one collated tensor produced by LeRobot against the original's.""" + assert isinstance(lerobot_value, torch.Tensor), ( + f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is " + f"{type(lerobot_value).__name__}, expected a tensor." + ) + lerobot_t = lerobot_value.detach().cpu() + original_t = original_value.detach().cpu() + assert lerobot_t.shape == original_t.shape, ( + f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs " + f"original={tuple(original_t.shape)}." + ) + if exact: + mismatched = int((lerobot_t.long() != original_t.long()).sum()) + assert mismatched == 0, ( + f"[{embodiment_tag}] collated '{name}' differs from the original processor output: " + f"{mismatched}/{original_t.numel()} elements mismatch." + ) + else: + lerobot_f, original_f = lerobot_t.float(), original_t.float() + max_diff = (lerobot_f - original_f).abs().max().item() + print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}") + assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), ( + f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond " + f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}." + ) + + @pytest.fixture(scope="module") def lerobot_model(): """Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags.""" @@ -165,8 +256,7 @@ def lerobot_model(): _ARTIFACTS = _discover_artifacts() - -@pytest.mark.skipif( +_requires_artifacts = pytest.mark.skipif( not _ARTIFACTS, reason=( "No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t " @@ -174,24 +264,30 @@ _ARTIFACTS = _discover_artifacts() "--ckpt --out-dir tests/policies/groot/artifacts --device cuda" ), ) + + +@_requires_artifacts @pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS]) def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model): """Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot.""" - original_action, flat_inputs = _load_artifact(artifact) + original_action, flat_inputs, seed = _load_artifact(artifact) model_inputs = _unflatten(flat_inputs) # Align the flow-matching RNG exactly as the producer did (seed right before sampling). - torch.manual_seed(SEED) + torch.manual_seed(seed) if torch.cuda.is_available(): - torch.cuda.manual_seed_all(SEED) + torch.cuda.manual_seed_all(seed) with torch.inference_mode(): out = lerobot_model.get_action(model_inputs) lerobot_action = out["action_pred"].float().cpu() - t = min(original_action.shape[1], lerobot_action.shape[1]) - d = min(original_action.shape[2], lerobot_action.shape[2]) - original_action = original_action[:, :t, :d] - lerobot_action = lerobot_action[:, :t, :d] + assert lerobot_action.shape == original_action.shape, ( + f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': " + f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. " + "The same checkpoint and inputs must produce identical shapes; this indicates an " + "action-horizon or action-dim regression (or a stale artifact -- regenerate it with " + "utils/dump_original_n1_7.py)." + ) diff = torch.abs(lerobot_action - original_action) max_diff = diff.max().item() @@ -205,3 +301,56 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model): f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond " f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}" ) + + +@_requires_artifacts +@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS]) +def test_groot_preprocessor_parity(embodiment_tag, artifact): + """LeRobot's real preprocessor vs the original's collated tensors, from identical raw obs. + + Runs LeRobot's full preprocessor pipeline -- including the real Qwen3-VL chat + template, tokenizer and image packing plus the checkpoint-driven state + normalization (no mocks) -- on the raw observations recorded in the artifact, and + compares every collated model input against the ones the original ``gr00t`` + processor produced from the same raw observations. + """ + raw = _load_raw_observation(artifact) + if raw is None: + pytest.skip( + f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does " + "not record raw observations; regenerate it with the current dump script to run the " + "preprocessor parity case." + ) + _, flat_inputs, _ = _load_artifact(artifact) + original_inputs = _unflatten(flat_inputs) + + ckpt = _resolve_checkpoint() + from lerobot.policies.groot.configuration_groot import GrootConfig + from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors + + # CPU keeps this case runnable without a GPU; the preprocessor is deterministic. + config = GrootConfig(base_model_path=ckpt, embodiment_tag=embodiment_tag, device="cpu") + preprocessor, _ = make_groot_pre_post_processors(config) + + processed = preprocessor(_raw_observation_to_lerobot_batch(raw)) + + compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS) + missing_original = [k for k in compared_keys if k not in original_inputs] + missing_lerobot = [k for k in compared_keys if k not in processed] + assert not missing_original, ( + f"[{embodiment_tag}] artifact collated inputs miss {missing_original} " + f"(available: {sorted(original_inputs)}); regenerate the artifact with the current dump script." + ) + assert not missing_lerobot, ( + f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys " + f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})." + ) + + for name in compared_keys: + _assert_collated_parity( + embodiment_tag, + name, + processed[name], + original_inputs[name], + exact=name in _COLLATED_EXACT_KEYS, + ) diff --git a/tests/policies/groot/utils/dump_original_n1_7.py b/tests/policies/groot/utils/dump_original_n1_7.py index 26d1cd10c..47ba8f611 100644 --- a/tests/policies/groot/utils/dump_original_n1_7.py +++ b/tests/policies/groot/utils/dump_original_n1_7.py @@ -9,6 +9,9 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno imported in the same Python process. To keep the parity comparison FAIR, we run the original model in its native env here and serialize, PER EMBODIMENT TAG: + * the RAW observation fed to the original processor (per-camera uint8 frames, + per-key state vectors, the language instruction), so the LeRobot side can also + run its OWN preprocessor on identical raw inputs and compare collated tensors, * the exact pre-processed/collated model inputs (so the LeRobot side consumes the byte-identical tensors -- same image preprocessing, tokenization, normalization), * the random seed used right before the flow-matching sampler, @@ -21,8 +24,10 @@ processor's per-embodiment modality configs. This lets us test many embodiment t from the SAME checkpoint and confirm the LeRobot integration is not overfit to ``libero_sim``. -The companion pytest (run in the LeRobot env) loads each .npz, replays the identical -inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match. +The companion pytest (run in the LeRobot env) loads each .npz and asserts parity +twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model +(model parity), and the raw observation is replayed through LeRobot's own +preprocessor pipeline and compared against the collated inputs (preprocessor parity). Usage: .venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \ @@ -62,10 +67,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec): # One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics. # Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as # present-but-empty so the processor's state transform finds every expected key. - state = { - k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) - for k, dim in state_spec - } + state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec} language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]} return {"video": video, "state": state, "language": language} @@ -77,6 +79,25 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa lang_key = modality_cfg["language"].modality_keys[0] observation = make_observation(args.seed, video_keys, lang_key, state_spec) + # Snapshot the RAW observation exactly as fed to the original processor below. The + # consumer's preprocessor-parity case replays it through LeRobot's own preprocessor + # and compares the resulting collated tensors against the "in::" ones saved further + # down. raw_state_keys records the checkpoint modality-key order, which is the + # concatenation order of the flat LeRobot ``observation.state`` vector. + spec_keys = [key for key, _ in state_spec] + state_modality = modality_cfg.get("state") + state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else [] + state_keys += [key for key in spec_keys if key not in state_keys] + raw_language = [ + str(item[0]) if isinstance(item, (list, tuple)) else str(item) + for item in observation["language"][lang_key] + ] + raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()} + raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()}) + raw_flat["raw::language"] = np.array(raw_language, dtype=object) + raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object) + raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object) + # Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__). policy.embodiment_tag = type(policy.embodiment_tag)(tag) policy.modality_configs = { @@ -136,6 +157,7 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa embodiment_tag=np.array(tag), meta_keys=np.array(list(meta.keys()), dtype=object), meta_dtypes=np.array(list(meta.values()), dtype=object), + **raw_flat, **flat, ) print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)") @@ -181,7 +203,12 @@ def main(): state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()] try: dump_one_tag( - policy, fair_model, tag, all_modality[tag], state_spec, args, + policy, + fair_model, + tag, + all_modality[tag], + state_spec, + args, out_dir / f"original_n1_7_{tag}.npz", ) done.append(tag)