mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
test(groot): parametrize N1.7 parity across all checkpoint embodiments
Generalize the original-vs-LeRobot N1.7 output-parity test from a single libero_sim case to every embodiment tag in the checkpoint (libero_sim, oxe_droid, real_g1, the real_r1_pro_sharpa family, and the xdof family). Inputs are built generically from checkpoint metadata; the test discovers per-tag .npz artifacts and runs one parametrized case each, loading the LeRobot model once via a fixture. All 9 embodiments match the original to fp32 epsilon (max|diff| < 3e-6), confirming the integration is correct across the model's full embodiment space and not overfit to libero_sim.
This commit is contained in:
@@ -23,6 +23,16 @@ the normalized flow-matching prediction before any action decoding) as NVIDIA's
|
||||
original ``gr00t`` package, given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed.
|
||||
|
||||
MULTIPLE EMBODIMENTS (anti-overfitting)
|
||||
---------------------------------------
|
||||
The comparison is parametrized over EVERY embodiment tag present in the checkpoint
|
||||
(``libero_sim`` plus the cross-embodiment tags it was trained with: oxe_droid,
|
||||
real_g1, the real_r1_pro_sharpa family, and the xdof family). Inputs for each tag are
|
||||
built generically from the checkpoint metadata (state dims from ``statistics.json``,
|
||||
camera/language keys from the processor modality configs), so passing on all of them
|
||||
shows the LeRobot integration is correct across the model's full embodiment space and
|
||||
not merely tuned for ``libero_sim``.
|
||||
|
||||
WHY TWO ENVIRONMENTS
|
||||
--------------------
|
||||
The original ``gr00t`` package pins ``transformers==4.57.3`` (Python 3.10) and its
|
||||
@@ -33,16 +43,16 @@ argument follows default argument"). The two implementations therefore CANNOT be
|
||||
imported in the same Python process.
|
||||
|
||||
To keep the comparison fair, the original outputs + the exact collated inputs are
|
||||
produced once in the original ``gr00t`` env via
|
||||
``groot_vs_lerobot/scripts/dump_original_n1_7.py`` and saved to an ``.npz``. This
|
||||
test loads that artifact, replays the identical inputs through the LeRobot model,
|
||||
and compares.
|
||||
produced once per embodiment in the original ``gr00t`` env via
|
||||
``groot_vs_lerobot/scripts/dump_original_n1_7.py`` and saved to per-tag ``.npz``
|
||||
files. This test discovers those artifacts, replays the identical inputs through the
|
||||
LeRobot model, and compares.
|
||||
|
||||
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
|
||||
present, or when the artifact has not been generated. No hardcoded paths: the
|
||||
artifact location comes from ``GROOT_N1_7_PARITY_NPZ`` (default:
|
||||
``groot_vs_lerobot/artifacts/original_n1_7_libero.npz`` relative to the repo root).
|
||||
See ``groot_vs_lerobot/README.md`` for the full run procedure.
|
||||
present, or when no artifact has been generated. No hardcoded paths: the artifact
|
||||
directory comes from ``GROOT_N1_7_PARITY_DIR`` (default:
|
||||
``../groot_vs_lerobot/artifacts`` relative to the repo root, i.e. beside the checkout). See
|
||||
``groot_vs_lerobot/README.md`` for the full run procedure.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -54,26 +64,40 @@ import torch
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="Requires a local GR00T N1.7 checkpoint + a pre-generated artifact; not for CI.",
|
||||
reason="Requires a local GR00T N1.7 checkpoint + pre-generated artifacts; not for CI.",
|
||||
)
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
||||
|
||||
EMBODIMENT_TAG = "libero_sim"
|
||||
SEED = 42
|
||||
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
||||
RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
|
||||
|
||||
# Artifact filenames are original_n1_7_<embodiment_tag>.npz
|
||||
_ARTIFACT_PREFIX = "original_n1_7_"
|
||||
_ARTIFACT_SUFFIX = ".npz"
|
||||
|
||||
def _artifact_path() -> Path:
|
||||
env = os.environ.get("GROOT_N1_7_PARITY_NPZ")
|
||||
|
||||
def _artifact_dir() -> Path:
|
||||
env = os.environ.get("GROOT_N1_7_PARITY_DIR")
|
||||
if env:
|
||||
return Path(env)
|
||||
# repo_root/tests/policies/groot/<this file> -> repo_root parent holds groot_vs_lerobot/
|
||||
# repo_root/tests/policies/groot/<this file> -> repo_root.parent holds groot_vs_lerobot/
|
||||
repo_root = Path(__file__).resolve().parents[3]
|
||||
# The companion workspace lives alongside the repo, not inside it.
|
||||
return repo_root.parent / "groot_vs_lerobot" / "artifacts" / "original_n1_7_libero.npz"
|
||||
return repo_root.parent / "groot_vs_lerobot" / "artifacts"
|
||||
|
||||
|
||||
def _discover_artifacts() -> list[tuple[str, Path]]:
    """Return ``[(embodiment_tag, npz_path), ...]`` for every dumped artifact.

    Scans the directory returned by ``_artifact_dir()`` for files named
    ``original_n1_7_<embodiment_tag>.npz`` and derives the tag from the file name.
    Entries are ordered by sorted path so the parametrized test order is stable.

    Returns:
        A list of ``(tag, path)`` pairs; empty when the artifact directory does not
        exist or holds no matching files.
    """
    artifact_dir = _artifact_dir()
    if not artifact_dir.is_dir():
        return []
    found: list[tuple[str, Path]] = []
    for path in sorted(artifact_dir.glob(f"{_ARTIFACT_PREFIX}*{_ARTIFACT_SUFFIX}")):
        # Only regular files can be loaded; a directory whose name happens to match
        # the pattern is not an artifact.
        if not path.is_file():
            continue
        tag = path.name.removeprefix(_ARTIFACT_PREFIX).removesuffix(_ARTIFACT_SUFFIX)
        # "original_n1_7_.npz" carries no embodiment tag and would produce an empty
        # test id, so it is ignored.
        if tag:
            found.append((tag, path))
    return found
|
||||
|
||||
|
||||
def _resolve_checkpoint() -> str:
|
||||
@@ -99,12 +123,6 @@ def _resolve_checkpoint() -> str:
|
||||
|
||||
|
||||
def _load_artifact(path: Path):
|
||||
if not path.exists():
|
||||
pytest.skip(
|
||||
f"Parity artifact not found at {path}. Generate it first in the original gr00t "
|
||||
f"env:\n .venv-original/bin/python groot_vs_lerobot/scripts/dump_original_n1_7.py "
|
||||
f"--ckpt <ckpt> --out {path} --device cuda --seed {SEED}"
|
||||
)
|
||||
data = np.load(path, allow_pickle=True)
|
||||
original_action = torch.from_numpy(data["action_pred"]).float()
|
||||
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
||||
@@ -115,7 +133,6 @@ def _load_artifact(path: Path):
|
||||
name = key[4:]
|
||||
arr = data[key]
|
||||
t = torch.from_numpy(np.asarray(arr))
|
||||
# Restore integer dtypes that np may have widened.
|
||||
declared = dtypes.get(key, "")
|
||||
if "int" in declared or "long" in declared:
|
||||
t = t.long()
|
||||
@@ -132,25 +149,15 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||
for p in parts[:-1]:
|
||||
cur = cur.setdefault(p, {})
|
||||
cur[parts[-1]] = value
|
||||
# The producer flattened the top-level collated dict; "inputs" is its only branch.
|
||||
return nested.get("inputs", nested)
|
||||
|
||||
|
||||
def test_groot_n1_7_get_action_parity():
|
||||
"""Raw model.get_action(action_pred) parity: original gr00t vs LeRobot integration."""
|
||||
@pytest.fixture(scope="module")
|
||||
def lerobot_model():
|
||||
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
||||
ckpt = _resolve_checkpoint()
|
||||
original_action, flat_inputs = _load_artifact(_artifact_path())
|
||||
|
||||
# Load the underlying GR00T N1.7 model directly (mirrors the original side, which
|
||||
# calls ``policy.model.get_action``). This bypasses the LeRobot policy feature
|
||||
# pipeline so the comparison is strictly between the two model reimplementations
|
||||
# on identical pre-processed inputs.
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
# Run fp32 + SDPA on the LeRobot side to match the producer exactly (the original
|
||||
# artifact is dumped fp32 + SDPA). bf16 + differing attention kernels otherwise
|
||||
# introduce ~1e-2 numerical noise unrelated to the implementations.
|
||||
dtype = torch.float32
|
||||
model = GR00TN17.from_pretrained(
|
||||
ckpt,
|
||||
tune_llm=False,
|
||||
@@ -160,11 +167,30 @@ def test_groot_n1_7_get_action_parity():
|
||||
tune_vlln=False,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
# fp32 + SDPA on both sides: bf16 + differing attention kernels otherwise introduce
|
||||
# ~1e-2 numerical noise unrelated to the implementations.
|
||||
model.compute_dtype = "float32"
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
model.to(device=DEVICE, dtype=dtype)
|
||||
model.to(device=DEVICE, dtype=torch.float32)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
_ARTIFACTS = _discover_artifacts()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _ARTIFACTS,
|
||||
reason=(
|
||||
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
||||
"env:\n .venv-original/bin/python groot_vs_lerobot/scripts/dump_original_n1_7.py "
|
||||
"--ckpt <ckpt> --out-dir groot_vs_lerobot/artifacts --device cuda"
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||
def test_groot_n1_7_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
||||
original_action, flat_inputs = _load_artifact(artifact)
|
||||
model_inputs = _unflatten(flat_inputs)
|
||||
|
||||
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
||||
@@ -172,7 +198,7 @@ def test_groot_n1_7_get_action_parity():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(SEED)
|
||||
with torch.inference_mode():
|
||||
out = model.get_action(model_inputs)
|
||||
out = lerobot_model.get_action(model_inputs)
|
||||
lerobot_action = out["action_pred"].float().cpu()
|
||||
|
||||
t = min(original_action.shape[1], lerobot_action.shape[1])
|
||||
@@ -181,17 +207,14 @@ def test_groot_n1_7_get_action_parity():
|
||||
lerobot_action = lerobot_action[:, :t, :d]
|
||||
|
||||
diff = torch.abs(lerobot_action - original_action)
|
||||
print(f"\nShapes: lerobot={tuple(lerobot_action.shape)} original={tuple(original_action.shape)}")
|
||||
print(f"{'idx':<5}{'LeRobot':>14}{'Original':>14}{'|diff|':>14}")
|
||||
for di in range(min(8, lerobot_action.shape[-1])):
|
||||
lr = lerobot_action[0, 0, di].item()
|
||||
og = original_action[0, 0, di].item()
|
||||
print(f"{di:<5}{lr:>14.6f}{og:>14.6f}{abs(lr - og):>14.6f}")
|
||||
max_diff = diff.max().item()
|
||||
print(f"\nmax|diff| = {max_diff:.6e} mean|diff| = {diff.mean().item():.6e}")
|
||||
print(
|
||||
f"\n[{embodiment_tag}] shapes lerobot={tuple(lerobot_action.shape)} "
|
||||
f"original={tuple(original_action.shape)} "
|
||||
f"max|diff|={max_diff:.6e} mean|diff|={diff.mean().item():.6e}"
|
||||
)
|
||||
|
||||
assert torch.allclose(lerobot_action, original_action, atol=ATOL, rtol=RTOL), (
|
||||
f"GR00T N1.7 raw action_pred differs beyond atol={ATOL}, rtol={RTOL}: "
|
||||
f"max|diff|={max_diff:.6e}"
|
||||
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
||||
)
|
||||
print(f"\nSUCCESS: GR00T N1.7 raw outputs match (max|diff|={max_diff:.6e}, atol={ATOL})")
|
||||
|
||||
Reference in New Issue
Block a user