mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b2223c6162 |
@@ -14,31 +14,36 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||||
|
|
||||||
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
|
Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
|
||||||
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
|
once in the original ``gr00t`` env by the companion script
|
||||||
normalized flow-matching prediction before any action decoding) as NVIDIA's original
|
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):
|
||||||
``gr00t`` package, given byte-identical pre-processed inputs and the same
|
|
||||||
flow-matching seed. The comparison is parametrized over every embodiment tag present
|
|
||||||
in the checkpoint.
|
|
||||||
|
|
||||||
To keep the comparison fair, the original outputs + the exact collated inputs are
|
1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
|
||||||
produced once per embodiment in the original ``gr00t`` env via the companion script
|
action head + Qwen3-VL backbone must produce the SAME raw model output
|
||||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
|
(``action_pred``, the normalized flow-matching prediction before any action
|
||||||
to per-tag ``.npz`` files.
|
decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
|
||||||
This test discovers those artifacts, replays the identical inputs through the LeRobot
|
pre-processed inputs and the flow-matching seed recorded in the artifact.
|
||||||
model, and compares.
|
2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
|
||||||
|
template / tokenizer / image packing + state normalization, no mocks) must produce
|
||||||
|
the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
|
||||||
|
as the original package's processor, given the identical raw observations
|
||||||
|
(images, state, language) recorded in the artifact. Artifacts written by older
|
||||||
|
versions of the dump script carry no raw observations; this case then SKIPS with
|
||||||
|
a regeneration hint.
|
||||||
|
|
||||||
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
|
These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
|
||||||
present, or when no artifact has been generated. By default it looks for artifacts in
|
present, or when no artifact has been generated. By default they look for artifacts in
|
||||||
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
||||||
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
||||||
for the full run procedure.
|
for the full run procedure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -50,7 +55,9 @@ pytestmark = pytest.mark.skipif(
|
|||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||||
|
|
||||||
|
# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
|
||||||
SEED = 42
|
SEED = 42
|
||||||
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
||||||
@@ -60,6 +67,11 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
|
|||||||
_ARTIFACT_PREFIX = "original_n1_7_"
|
_ARTIFACT_PREFIX = "original_n1_7_"
|
||||||
_ARTIFACT_SUFFIX = ".npz"
|
_ARTIFACT_SUFFIX = ".npz"
|
||||||
|
|
||||||
|
# Collated keys compared by the preprocessor parity case: integer/id tensors must
|
||||||
|
# match exactly; float tensors within ATOL/RTOL.
|
||||||
|
_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
|
||||||
|
_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
|
||||||
|
|
||||||
|
|
||||||
def _artifact_dir() -> Path:
|
def _artifact_dir() -> Path:
|
||||||
"""Directory holding the per-embodiment .npz artifacts.
|
"""Directory holding the per-embodiment .npz artifacts.
|
||||||
@@ -109,9 +121,20 @@ def _resolve_checkpoint() -> str:
|
|||||||
return str(ckpt)
|
return str(ckpt)
|
||||||
|
|
||||||
|
|
||||||
def _load_artifact(path: Path):
|
def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
|
||||||
|
"""Return (original action_pred, collated model inputs, flow-matching seed)."""
|
||||||
data = np.load(path, allow_pickle=True)
|
data = np.load(path, allow_pickle=True)
|
||||||
original_action = torch.from_numpy(data["action_pred"]).float()
|
original_action = torch.from_numpy(data["action_pred"]).float()
|
||||||
|
if "seed" in data.files:
|
||||||
|
seed = int(data["seed"])
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
f"Artifact '{path.name}' does not record the producer seed (it predates the current "
|
||||||
|
f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
|
||||||
|
"regenerate the artifact with the current dump script.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
seed = SEED
|
||||||
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
||||||
inputs = {}
|
inputs = {}
|
||||||
for key in data.files:
|
for key in data.files:
|
||||||
@@ -124,7 +147,45 @@ def _load_artifact(path: Path):
|
|||||||
if "int" in declared or "long" in declared:
|
if "int" in declared or "long" in declared:
|
||||||
t = t.long()
|
t = t.long()
|
||||||
inputs[name] = t
|
inputs[name] = t
|
||||||
return original_action, inputs
|
return original_action, inputs, seed
|
||||||
|
|
||||||
|
|
||||||
|
def _load_raw_observation(path: Path) -> dict[str, Any] | None:
|
||||||
|
"""Return the raw observation recorded in the artifact, or None for old artifacts.
|
||||||
|
|
||||||
|
Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
|
||||||
|
exact raw observation the producer fed to the original processor: per-camera uint8
|
||||||
|
frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
|
||||||
|
(``raw::state.<key>``, (B, T, dim)) and the language instruction
|
||||||
|
(``raw::language``, one string per batch element). ``raw_video_keys`` /
|
||||||
|
``raw_state_keys`` record the checkpoint modality-key order.
|
||||||
|
"""
|
||||||
|
data = np.load(path, allow_pickle=True)
|
||||||
|
markers = ("raw_video_keys", "raw_state_keys", "raw::language")
|
||||||
|
if any(marker not in data.files for marker in markers):
|
||||||
|
return None
|
||||||
|
video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
|
||||||
|
state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
|
||||||
|
return {
|
||||||
|
"video": {k: data[f"raw::video.{k}"] for k in video_keys},
|
||||||
|
"state": {k: data[f"raw::state.{k}"] for k in state_keys},
|
||||||
|
"language": [str(t) for t in data["raw::language"].tolist()],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
    """Build a LeRobot policy batch out of the producer's raw observation.

    The result holds one ``{OBS_IMAGES}.<key>`` tensor per entry of ``raw["video"]``,
    a single flat ``OBS_STATE`` tensor and the ``"task"`` list of instructions.
    """
    # Each camera is permuted with (0, 1, 4, 2, 3): the trailing channel axis moves in
    # front of the two spatial axes. The permute only reorders axes, no values change.
    batch: dict[str, Any] = {
        f"{OBS_IMAGES}.{camera}": torch.from_numpy(clip).permute(0, 1, 4, 2, 3).contiguous()
        for camera, clip in raw["video"].items()
    }

    # Take the last index along axis 1 of every state array and join the pieces along
    # the final axis, following the iteration order of ``raw["state"]``.
    # NOTE(review): that order is presumably the checkpoint modality-key order the
    # LeRobot pack step and flattened statistics expect -- confirm against the loader.
    latest_states = []
    for values in raw["state"].values():
        latest = np.asarray(values)[:, -1, :]
        latest_states.append(torch.from_numpy(latest).float())
    batch[OBS_STATE] = torch.cat(latest_states, dim=-1)

    batch["task"] = [*raw["language"]]
    return batch
|
||||||
|
|
||||||
|
|
||||||
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||||
@@ -139,6 +200,36 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
|||||||
return nested.get("inputs", nested)
|
return nested.get("inputs", nested)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_collated_parity(
    embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
) -> None:
    """Compare one collated tensor produced by LeRobot against the original's.

    Args:
        embodiment_tag: Tag used only to prefix assertion / log messages.
        name: Collated key being compared (e.g. ``input_ids``), used in messages.
        lerobot_value: Value produced by the LeRobot preprocessor; must be a tensor.
        original_value: Reference tensor recorded from the original processor.
        exact: When True, elements are compared for equality after a cast to int64;
            otherwise they are compared as float32 within ``ATOL`` / ``RTOL``.

    Raises:
        AssertionError: If ``lerobot_value`` is not a tensor, the shapes differ, or
            the values differ under the selected comparison mode.
    """
    assert isinstance(lerobot_value, torch.Tensor), (
        f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
        f"{type(lerobot_value).__name__}, expected a tensor."
    )
    lerobot_t = lerobot_value.detach().cpu()
    original_t = original_value.detach().cpu()
    assert lerobot_t.shape == original_t.shape, (
        f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs "
        f"original={tuple(original_t.shape)}."
    )
    if lerobot_t.numel() == 0:
        # Equal-shaped empty tensors are trivially identical, and the ``max()``
        # reduction below raises a RuntimeError on zero-element input.
        return
    if exact:
        mismatched = int((lerobot_t.long() != original_t.long()).sum())
        assert mismatched == 0, (
            f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
            f"{mismatched}/{original_t.numel()} elements mismatch."
        )
    else:
        lerobot_f, original_f = lerobot_t.float(), original_t.float()
        max_diff = (lerobot_f - original_f).abs().max().item()
        # Printed even on success so the observed drift is visible when output is shown.
        print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}")
        assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), (
            f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
            f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
        )
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lerobot_model():
|
def lerobot_model():
|
||||||
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
||||||
@@ -165,8 +256,7 @@ def lerobot_model():
|
|||||||
|
|
||||||
_ARTIFACTS = _discover_artifacts()
|
_ARTIFACTS = _discover_artifacts()
|
||||||
|
|
||||||
|
_requires_artifacts = pytest.mark.skipif(
|
||||||
@pytest.mark.skipif(
|
|
||||||
not _ARTIFACTS,
|
not _ARTIFACTS,
|
||||||
reason=(
|
reason=(
|
||||||
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
||||||
@@ -174,24 +264,30 @@ _ARTIFACTS = _discover_artifacts()
|
|||||||
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_requires_artifacts
|
||||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||||
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||||
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
||||||
original_action, flat_inputs = _load_artifact(artifact)
|
original_action, flat_inputs, seed = _load_artifact(artifact)
|
||||||
model_inputs = _unflatten(flat_inputs)
|
model_inputs = _unflatten(flat_inputs)
|
||||||
|
|
||||||
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
||||||
torch.manual_seed(SEED)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(SEED)
|
torch.cuda.manual_seed_all(seed)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
out = lerobot_model.get_action(model_inputs)
|
out = lerobot_model.get_action(model_inputs)
|
||||||
lerobot_action = out["action_pred"].float().cpu()
|
lerobot_action = out["action_pred"].float().cpu()
|
||||||
|
|
||||||
t = min(original_action.shape[1], lerobot_action.shape[1])
|
assert lerobot_action.shape == original_action.shape, (
|
||||||
d = min(original_action.shape[2], lerobot_action.shape[2])
|
f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
|
||||||
original_action = original_action[:, :t, :d]
|
f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
|
||||||
lerobot_action = lerobot_action[:, :t, :d]
|
"The same checkpoint and inputs must produce identical shapes; this indicates an "
|
||||||
|
"action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
|
||||||
|
"utils/dump_original_n1_7.py)."
|
||||||
|
)
|
||||||
|
|
||||||
diff = torch.abs(lerobot_action - original_action)
|
diff = torch.abs(lerobot_action - original_action)
|
||||||
max_diff = diff.max().item()
|
max_diff = diff.max().item()
|
||||||
@@ -205,3 +301,56 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
|||||||
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
||||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_requires_artifacts
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
def test_groot_preprocessor_parity(embodiment_tag, artifact):
    """Preprocessor parity per embodiment: LeRobot's pipeline vs the original's tensors.

    The raw observations stored in the artifact are pushed through the preprocessor
    returned by ``make_groot_pre_post_processors`` (built on CPU from the resolved
    checkpoint), and every key listed in ``_COLLATED_EXACT_KEYS`` /
    ``_COLLATED_CLOSE_KEYS`` is compared against the collated inputs recorded from the
    original ``gr00t`` processor. Artifacts without raw observations skip the case.
    """
    raw_observation = _load_raw_observation(artifact)
    if raw_observation is None:
        pytest.skip(
            f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
            "not record raw observations; regenerate it with the current dump script to run the "
            "preprocessor parity case."
        )
    _, recorded_flat, _ = _load_artifact(artifact)
    reference_inputs = _unflatten(recorded_flat)

    checkpoint = _resolve_checkpoint()
    # Imported here rather than at module top so they only load once the case runs.
    from lerobot.policies.groot.configuration_groot import GrootConfig
    from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors

    # CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
    config = GrootConfig(base_model_path=checkpoint, embodiment_tag=embodiment_tag, device="cpu")
    preprocessor, _ = make_groot_pre_post_processors(config)

    lerobot_batch = _raw_observation_to_lerobot_batch(raw_observation)
    processed = preprocessor(lerobot_batch)

    # Exact keys first, then tolerance keys -- the order in which they are compared.
    compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS)
    missing_original = []
    for key in compared_keys:
        if key not in reference_inputs:
            missing_original.append(key)
    missing_lerobot = []
    for key in compared_keys:
        if key not in processed:
            missing_lerobot.append(key)

    assert not missing_original, (
        f"[{embodiment_tag}] artifact collated inputs miss {missing_original} "
        f"(available: {sorted(reference_inputs)}); regenerate the artifact with the current dump script."
    )
    assert not missing_lerobot, (
        f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys "
        f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})."
    )

    for key in compared_keys:
        must_match_exactly = key in _COLLATED_EXACT_KEYS
        _assert_collated_parity(
            embodiment_tag,
            key,
            processed[key],
            reference_inputs[key],
            exact=must_match_exactly,
        )
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
|
|||||||
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
||||||
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
||||||
|
|
||||||
|
* the RAW observation fed to the original processor (per-camera uint8 frames,
|
||||||
|
per-key state vectors, the language instruction), so the LeRobot side can also
|
||||||
|
run its OWN preprocessor on identical raw inputs and compare collated tensors,
|
||||||
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
||||||
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
||||||
* the random seed used right before the flow-matching sampler,
|
* the random seed used right before the flow-matching sampler,
|
||||||
@@ -21,8 +24,10 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
|
|||||||
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
||||||
``libero_sim``.
|
``libero_sim``.
|
||||||
|
|
||||||
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
|
The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
|
||||||
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
|
twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
|
||||||
|
(model parity), and the raw observation is replayed through LeRobot's own
|
||||||
|
preprocessor pipeline and compared against the collated inputs (preprocessor parity).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
||||||
@@ -62,10 +67,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
|
|||||||
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
||||||
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
||||||
# present-but-empty so the processor's state transform finds every expected key.
|
# present-but-empty so the processor's state transform finds every expected key.
|
||||||
state = {
|
state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
|
||||||
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
|
|
||||||
for k, dim in state_spec
|
|
||||||
}
|
|
||||||
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
||||||
return {"video": video, "state": state, "language": language}
|
return {"video": video, "state": state, "language": language}
|
||||||
|
|
||||||
@@ -77,6 +79,25 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
|||||||
lang_key = modality_cfg["language"].modality_keys[0]
|
lang_key = modality_cfg["language"].modality_keys[0]
|
||||||
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
||||||
|
|
||||||
|
# Snapshot the RAW observation exactly as fed to the original processor below. The
|
||||||
|
# consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
|
||||||
|
# and compares the resulting collated tensors against the "in::" ones saved further
|
||||||
|
# down. raw_state_keys records the checkpoint modality-key order, which is the
|
||||||
|
# concatenation order of the flat LeRobot ``observation.state`` vector.
|
||||||
|
spec_keys = [key for key, _ in state_spec]
|
||||||
|
state_modality = modality_cfg.get("state")
|
||||||
|
state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
|
||||||
|
state_keys += [key for key in spec_keys if key not in state_keys]
|
||||||
|
raw_language = [
|
||||||
|
str(item[0]) if isinstance(item, (list, tuple)) else str(item)
|
||||||
|
for item in observation["language"][lang_key]
|
||||||
|
]
|
||||||
|
raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
|
||||||
|
raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
|
||||||
|
raw_flat["raw::language"] = np.array(raw_language, dtype=object)
|
||||||
|
raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
|
||||||
|
raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
|
||||||
|
|
||||||
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
||||||
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
||||||
policy.modality_configs = {
|
policy.modality_configs = {
|
||||||
@@ -136,6 +157,7 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
|||||||
embodiment_tag=np.array(tag),
|
embodiment_tag=np.array(tag),
|
||||||
meta_keys=np.array(list(meta.keys()), dtype=object),
|
meta_keys=np.array(list(meta.keys()), dtype=object),
|
||||||
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
||||||
|
**raw_flat,
|
||||||
**flat,
|
**flat,
|
||||||
)
|
)
|
||||||
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
||||||
@@ -181,7 +203,12 @@ def main():
|
|||||||
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
||||||
try:
|
try:
|
||||||
dump_one_tag(
|
dump_one_tag(
|
||||||
policy, fair_model, tag, all_modality[tag], state_spec, args,
|
policy,
|
||||||
|
fair_model,
|
||||||
|
tag,
|
||||||
|
all_modality[tag],
|
||||||
|
state_spec,
|
||||||
|
args,
|
||||||
out_dir / f"original_n1_7_{tag}.npz",
|
out_dir / f"original_n1_7_{tag}.npz",
|
||||||
)
|
)
|
||||||
done.append(tag)
|
done.append(tag)
|
||||||
|
|||||||
Reference in New Issue
Block a user