mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f81fc79564 | |||
| 9ce6633518 |
@@ -42,14 +42,6 @@ GROOT_N1_5_REMOVAL_GUIDANCE = (
|
|||||||
)
|
)
|
||||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||||
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
|
|
||||||
# falls back to these when a checkpoint ships no image sizing in its processor_config
|
|
||||||
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
|
|
||||||
# are resized to the expected resolution instead of being patchified at full camera
|
|
||||||
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
|
|
||||||
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
|
||||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
|
||||||
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
|
||||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||||
@@ -329,6 +321,9 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
|||||||
normalized = candidate.lower().replace("-", "_")
|
normalized = candidate.lower().replace("-", "_")
|
||||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||||
return GROOT_N1_7
|
return GROOT_N1_7
|
||||||
|
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
|
||||||
|
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
|
||||||
|
# instead of being silently treated as N1.7 (see infer_groot_model_version).
|
||||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||||
return GROOT_N1_5
|
return GROOT_N1_5
|
||||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||||
@@ -370,7 +365,11 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Groot-specific model parameters
|
# Deprecated and unused: image sizing is handled by the backbone's image processor.
|
||||||
|
# Kept only so config.json files saved with earlier versions still parse.
|
||||||
|
image_size: tuple[int, int] = (256, 256)
|
||||||
|
|
||||||
|
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||||
|
|
||||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||||
model_version: str = GROOT_N1_7
|
model_version: str = GROOT_N1_7
|
||||||
@@ -386,43 +385,14 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||||
|
|
||||||
|
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
|
||||||
|
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
|
||||||
|
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
|
||||||
|
tokenizer_assets_repo: str | None = None
|
||||||
|
|
||||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||||
embodiment_tag: str = "new_embodiment"
|
embodiment_tag: str = "new_embodiment"
|
||||||
|
|
||||||
# Inference-only override for the number of flow-matching denoising steps used to decode an
|
|
||||||
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
|
|
||||||
# inference speed for action quality; applied at base-model load via _create_groot_model.
|
|
||||||
num_inference_timesteps: int | None = None
|
|
||||||
|
|
||||||
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
|
|
||||||
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
|
|
||||||
execution_horizon: int | None = None
|
|
||||||
|
|
||||||
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
|
|
||||||
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
|
|
||||||
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
|
|
||||||
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
|
|
||||||
# base-model build (so it applies during lerobot-train, which uses __init__ not
|
|
||||||
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
|
|
||||||
warm_start_embodiment_slot: int | str | None = None
|
|
||||||
|
|
||||||
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
|
|
||||||
# When True, GR00T converts absolute->relative inside its own pack step (training) and
|
|
||||||
# reconstructs absolute inside its own flat decode step (inference), using a cached
|
|
||||||
# reference state. The dataset stays absolute; compute relative ACTION stats with
|
|
||||||
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
|
|
||||||
# "['gripper']"` (this only rewrites stats, not actions).
|
|
||||||
use_relative_actions: bool = False
|
|
||||||
|
|
||||||
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
|
|
||||||
# Case-insensitive token match against action_feature_names.
|
|
||||||
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
|
||||||
|
|
||||||
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
|
|
||||||
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
|
|
||||||
# identified and kept absolute. When None, the gripper cannot be identified.
|
|
||||||
action_feature_names: list[str] | None = None
|
|
||||||
|
|
||||||
# Fine-tuning control arguments
|
# Fine-tuning control arguments
|
||||||
|
|
||||||
# Whether to fine-tune the llm backbone
|
# Whether to fine-tune the llm backbone
|
||||||
@@ -458,13 +428,10 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
warmup_ratio: float = 0.05
|
warmup_ratio: float = 0.05
|
||||||
use_bf16: bool = True
|
use_bf16: bool = True
|
||||||
|
|
||||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
|
||||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
|
||||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||||
# would break every previously saved groot checkpoint at config-load time.
|
# would break every previously saved groot checkpoint at config-load time.
|
||||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
|
||||||
tokenizer_assets_repo: str | None = None
|
|
||||||
video_backend: str = "decord"
|
video_backend: str = "decord"
|
||||||
balance_dataset_weights: bool = True
|
balance_dataset_weights: bool = True
|
||||||
balance_trajectory_weights: bool = True
|
balance_trajectory_weights: bool = True
|
||||||
@@ -478,6 +445,9 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
resume: bool = False
|
resume: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
|
||||||
|
# serialized into every groot checkpoint config.json, so a value here means a legacy
|
||||||
|
# N1.5 checkpoint or config is being loaded.
|
||||||
if self.tokenizer_assets_repo is not None:
|
if self.tokenizer_assets_repo is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||||
@@ -612,11 +582,22 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list[int]:
|
def action_delta_indices(self) -> list[int]:
|
||||||
"""Return indices for delta actions."""
|
"""Return indices for delta actions.
|
||||||
|
|
||||||
|
The model action horizon is read from the checkpoint's processor_config.json
|
||||||
|
when available; the result is cached (keyed on the inputs that determine it) so
|
||||||
|
repeated access during dataset/training setup does not re-read from disk.
|
||||||
|
"""
|
||||||
|
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
|
||||||
|
cached = getattr(self, "_action_delta_indices_cache", None)
|
||||||
|
if cached is not None and cached[0] == cache_key:
|
||||||
|
return cached[1]
|
||||||
model_action_horizon = (
|
model_action_horizon = (
|
||||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||||
)
|
)
|
||||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
indices = list(range(min(self.chunk_size, model_action_horizon)))
|
||||||
|
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
|
||||||
|
return indices
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reward_delta_indices(self) -> None:
|
def reward_delta_indices(self) -> None:
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from torch.distributions import Beta
|
|||||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||||
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
|
||||||
|
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||||
@@ -72,13 +71,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
|||||||
"backbone_embedding_dim": 2048,
|
"backbone_embedding_dim": 2048,
|
||||||
"tune_llm": False,
|
"tune_llm": False,
|
||||||
"tune_visual": False,
|
"tune_visual": False,
|
||||||
"select_layer": 16,
|
"select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
|
||||||
"reproject_vision": False,
|
"reproject_vision": False,
|
||||||
"use_flash_attention": True,
|
"use_flash_attention": True,
|
||||||
"load_bf16": False,
|
"load_bf16": False,
|
||||||
"backbone_trainable_params_fp32": True,
|
"backbone_trainable_params_fp32": True,
|
||||||
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
"image_crop_size": (230, 230),
|
||||||
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
"image_target_size": (256, 256),
|
||||||
"shortest_image_edge": None,
|
"shortest_image_edge": None,
|
||||||
"crop_fraction": None,
|
"crop_fraction": None,
|
||||||
"random_rotation_angle": None,
|
"random_rotation_angle": None,
|
||||||
@@ -823,6 +822,8 @@ def get_backbone_cls(config: GR00TN17Config):
|
|||||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||||
return Qwen3Backbone
|
return Qwen3Backbone
|
||||||
if config.backbone_model_type == "qwen":
|
if config.backbone_model_type == "qwen":
|
||||||
|
# Local backbone checkpoints (e.g. hub-cache snapshot paths) contain neither hub
|
||||||
|
# marker, so trust the explicit backbone type but surface what is being assumed.
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||||
"backbone because backbone_model_type='qwen'.",
|
"backbone because backbone_model_type='qwen'.",
|
||||||
@@ -913,6 +914,10 @@ class GR00TN17(PreTrainedModel):
|
|||||||
"trust_remote_code": True
|
"trust_remote_code": True
|
||||||
}
|
}
|
||||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||||
|
# Only repo-agnostic hub kwargs are forwarded to the backbone loading kwargs:
|
||||||
|
# ``revision`` pins the GR00T checkpoint repo (see snapshot_download below) and would
|
||||||
|
# be invalid for the unrelated backbone repo (``config.model_name``). Pin the backbone
|
||||||
|
# itself by passing ``revision`` inside ``transformers_loading_kwargs``.
|
||||||
for key in ("cache_dir", "local_files_only", "token"):
|
for key in ("cache_dir", "local_files_only", "token"):
|
||||||
if key in kwargs:
|
if key in kwargs:
|
||||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||||
|
|||||||
@@ -54,98 +54,6 @@ logger = logging.getLogger(__name__)
|
|||||||
T = TypeVar("T", bound="GrootPolicy")
|
T = TypeVar("T", bound="GrootPolicy")
|
||||||
|
|
||||||
|
|
||||||
def _resolve_embodiment_id(value: int | str) -> int:
|
|
||||||
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
|
|
||||||
|
|
||||||
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
|
|
||||||
Raises ValueError listing the known keys if the name is unknown.
|
|
||||||
"""
|
|
||||||
from .processor_groot import N1_7_EMBODIMENT_MAPPING
|
|
||||||
|
|
||||||
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
|
|
||||||
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
|
|
||||||
if isinstance(value, int):
|
|
||||||
return value
|
|
||||||
if value in N1_7_EMBODIMENT_MAPPING:
|
|
||||||
return N1_7_EMBODIMENT_MAPPING[value]
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
|
|
||||||
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
    """Copy per-embodiment action-head weights from ``source_id`` to ``target_id``.

    For each of the action head's ``state_encoder``, ``action_encoder`` and
    ``action_decoder`` submodules, every listed layer that exposes both ``W`` and ``b``
    has row ``source_id`` of those tensors cloned into row ``target_id``. The copy runs
    under ``torch.no_grad()``.

    Nothing is copied (and a warning is logged) when the two ids are equal or when the
    model has no ``action_head``. A layer whose first ``W`` dimension does not contain
    both ids is skipped with a warning; the remaining layers are still processed.

    Args:
        model: Object expected to expose an ``action_head`` attribute.
        source_id: Index of the embodiment slot to copy from.
        target_id: Index of the embodiment slot to overwrite.
    """
    if source_id == target_id:
        logger.warning(
            "GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
            "skipping (nothing to copy).",
            source_id,
        )
        return

    head = getattr(model, "action_head", None)
    if head is None:
        logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
        return

    # Submodule attribute on the action head -> names of its per-category linear layers.
    layer_names_by_submodule = {
        "state_encoder": ("layer1", "layer2"),
        "action_encoder": ("W1", "W2", "W3"),
        "action_decoder": ("layer1", "layer2"),
    }

    copied_layers: list[str] = []
    with torch.no_grad():
        for submodule_attr, layer_names in layer_names_by_submodule.items():
            part = getattr(head, submodule_attr, None)
            if part is None:
                continue
            part_label = type(part).__name__
            for layer_name in layer_names:
                layer = getattr(part, layer_name, None)
                has_slot_params = layer is not None and hasattr(layer, "W") and hasattr(layer, "b")
                if not has_slot_params:
                    continue
                # The first dimension of W indexes the embodiment categories.
                slot_count = layer.W.shape[0]
                ids_in_range = 0 <= source_id < slot_count and 0 <= target_id < slot_count
                if not ids_in_range:
                    logger.warning(
                        "GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
                        "for %s.%s (num_categories=%d); skipping this layer.",
                        source_id,
                        target_id,
                        part_label,
                        layer_name,
                        slot_count,
                    )
                    continue
                layer.W.data[target_id] = layer.W.data[source_id].clone()
                layer.b.data[target_id] = layer.b.data[source_id].clone()
                copied_layers.append(f"{part_label}.{layer_name}")

    if not copied_layers:
        logger.warning(
            "GR00T warm_start_embodiment_slot: no action-head weights were copied "
            "(source_id=%d, target_id=%d).",
            source_id,
            target_id,
        )
        return

    logger.info(
        "GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
        "to slot %d for: %s.",
        source_id,
        target_id,
        ", ".join(copied_layers),
    )
|
|
||||||
|
|
||||||
|
|
||||||
class GrootPolicy(PreTrainedPolicy):
|
class GrootPolicy(PreTrainedPolicy):
|
||||||
"""Wrapper around external Groot model for LeRobot integration."""
|
"""Wrapper around external Groot model for LeRobot integration."""
|
||||||
|
|
||||||
@@ -185,24 +93,11 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
transformers_loading_kwargs={"trust_remote_code": True},
|
transformers_loading_kwargs={"trust_remote_code": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Inference-only override for the number of flow-matching denoising steps. The action
|
# GR00TN17 defines no compute_dtype attribute, so only record the
|
||||||
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
|
# bf16 preference when it is enabled instead of reading a default back.
|
||||||
# t schedule adapt automatically.
|
if self.config.use_bf16:
|
||||||
if self.config.num_inference_timesteps is not None:
|
model.compute_dtype = "bfloat16"
|
||||||
n = int(self.config.num_inference_timesteps)
|
model.config.compute_dtype = "bfloat16"
|
||||||
model.config.num_inference_timesteps = n
|
|
||||||
model.action_head.num_inference_timesteps = n
|
|
||||||
|
|
||||||
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
|
|
||||||
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
|
|
||||||
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
|
|
||||||
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
|
|
||||||
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
|
|
||||||
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
|
|
||||||
if self.config.warm_start_embodiment_slot is not None:
|
|
||||||
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
|
|
||||||
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
|
|
||||||
_warm_start_embodiment_slot(model, source_id, target_id)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -371,11 +266,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
horizons.append(checkpoint_action_horizon)
|
horizons.append(checkpoint_action_horizon)
|
||||||
if execution_horizon is not None:
|
if execution_horizon is not None:
|
||||||
horizons.append(execution_horizon)
|
horizons.append(execution_horizon)
|
||||||
# An explicit config override caps the open-loop horizon (inference cadence), overriding
|
return min(horizons)
|
||||||
# the value inferred from the checkpoint/embodiment.
|
|
||||||
if self.config.execution_horizon is not None:
|
|
||||||
horizons.append(max(1, int(self.config.execution_horizon)))
|
|
||||||
return max(1, min(horizons))
|
|
||||||
|
|
||||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||||
@@ -543,16 +434,6 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
"""
|
"""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
|
|
||||||
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
|
|
||||||
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
|
|
||||||
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
|
|
||||||
from .processor_groot import _GROOT_REF_HOLDER_KEY
|
|
||||||
|
|
||||||
holder = batch.get(_GROOT_REF_HOLDER_KEY)
|
|
||||||
if holder is not None:
|
|
||||||
holder.freeze()
|
|
||||||
|
|
||||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
||||||
# During inference, we do not pass action because it is predicted.
|
# During inference, we do not pass action because it is predicted.
|
||||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||||
|
|||||||
@@ -23,10 +23,9 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.v2.functional as tv_functional
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms import InterpolationMode
|
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
@@ -47,8 +46,6 @@ from lerobot.processor import (
|
|||||||
RenameObservationsProcessorStep,
|
RenameObservationsProcessorStep,
|
||||||
batch_to_transition,
|
batch_to_transition,
|
||||||
policy_action_to_transition,
|
policy_action_to_transition,
|
||||||
to_absolute_actions,
|
|
||||||
to_relative_actions,
|
|
||||||
transition_to_batch,
|
transition_to_batch,
|
||||||
transition_to_policy_action,
|
transition_to_policy_action,
|
||||||
)
|
)
|
||||||
@@ -61,14 +58,11 @@ from lerobot.utils.constants import (
|
|||||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
)
|
)
|
||||||
from lerobot.utils.device_utils import get_safe_torch_device
|
|
||||||
|
|
||||||
from .configuration_groot import (
|
from .configuration_groot import (
|
||||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||||
GROOT_N1_7_BACKBONE_MODEL,
|
GROOT_N1_7_BACKBONE_MODEL,
|
||||||
N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
|
||||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
|
||||||
GrootConfig,
|
GrootConfig,
|
||||||
is_raw_groot_n1_7_checkpoint,
|
is_raw_groot_n1_7_checkpoint,
|
||||||
)
|
)
|
||||||
@@ -90,30 +84,6 @@ N1_7_EMBODIMENT_MAPPING = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
_GROOT_REF_HOLDER_KEY = "_groot_relative_ref_holder" # private; dropped by _filter_groot_inputs, never reaches the model
|
|
||||||
|
|
||||||
|
|
||||||
class _GrootRelativeRefHolder:
|
|
||||||
"""Runtime-only carrier shared (by object identity) between the pack step (owner/writer of the
|
|
||||||
live reference), GrootPolicy.predict_action_chunk (freezes it at a real predict event), and the
|
|
||||||
decode step (reads the frozen reference). Not serialized. One instance per pack step."""
|
|
||||||
|
|
||||||
__slots__ = ("reference_state", "raw_state", "frozen_reference", "frozen_raw")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.reference_state = None
|
|
||||||
self.raw_state = None
|
|
||||||
self.frozen_reference = None
|
|
||||||
self.frozen_raw = None
|
|
||||||
|
|
||||||
def freeze(self) -> None:
|
|
||||||
self.frozen_reference = self.reference_state
|
|
||||||
self.frozen_raw = self.raw_state
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
self.reference_state = self.raw_state = self.frozen_reference = self.frozen_raw = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _GrootN17CheckpointProcessorAssets:
|
class _GrootN17CheckpointProcessorAssets:
|
||||||
"""Processor metadata loaded from a raw Isaac-GR00T N1.7 checkpoint.
|
"""Processor metadata loaded from a raw Isaac-GR00T N1.7 checkpoint.
|
||||||
@@ -143,39 +113,6 @@ class _GrootN17CheckpointProcessorAssets:
|
|||||||
use_albumentations: bool
|
use_albumentations: bool
|
||||||
|
|
||||||
|
|
||||||
def _resolve_base_model_local_dir(base_model_path: str | None) -> str | None:
|
|
||||||
"""Resolve a base model path to a local snapshot dir holding its sidecar JSONs.
|
|
||||||
|
|
||||||
``is_raw_groot_n1_7_checkpoint`` needs a local directory (or config.json) to inspect, so a
|
|
||||||
bare HF repo-id (e.g. ``nvidia/GR00T-N1.7-3B``) would never be recognised as a raw N1.7
|
|
||||||
checkpoint and the processor would fall back to LeRobot default image geometry instead of the
|
|
||||||
checkpoint's processor_config.json geometry. When the path is not already a local dir, this
|
|
||||||
downloads just the JSON sidecars and returns the local snapshot dir. Offline-safe: any failure
|
|
||||||
returns the original string unchanged. Only used on the fresh-build (training) path; inference
|
|
||||||
loads the serialized processor, so no per-inference network call is added.
|
|
||||||
"""
|
|
||||||
if base_model_path is None:
|
|
||||||
return None
|
|
||||||
if Path(base_model_path).expanduser().is_dir():
|
|
||||||
return base_model_path
|
|
||||||
try:
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
local_dir = snapshot_download(
|
|
||||||
base_model_path,
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*.json"],
|
|
||||||
)
|
|
||||||
logging.debug(
|
|
||||||
"Resolved GR00T base model '%s' to local snapshot '%s' for processor asset loading.",
|
|
||||||
base_model_path,
|
|
||||||
local_dir,
|
|
||||||
)
|
|
||||||
return local_dir
|
|
||||||
except Exception: # noqa: BLE001 (offline-safe: fall back to the original path on any failure)
|
|
||||||
return base_model_path
|
|
||||||
|
|
||||||
|
|
||||||
def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None:
|
def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None:
|
||||||
"""Load N1.7 processor settings from checkpoint sidecar JSON files.
|
"""Load N1.7 processor settings from checkpoint sidecar JSON files.
|
||||||
|
|
||||||
@@ -183,11 +120,10 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
|||||||
can keep using caller-provided dataset stats and config values.
|
can keep using caller-provided dataset stats and config values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resolved_base_model_path = _resolve_base_model_local_dir(config.base_model_path)
|
if not is_raw_groot_n1_7_checkpoint(config.base_model_path):
|
||||||
if not is_raw_groot_n1_7_checkpoint(resolved_base_model_path):
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
checkpoint_path = Path(resolved_base_model_path).expanduser()
|
checkpoint_path = Path(config.base_model_path).expanduser()
|
||||||
processor_config = _read_json(checkpoint_path / "processor_config.json")
|
processor_config = _read_json(checkpoint_path / "processor_config.json")
|
||||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||||
if not isinstance(processor_kwargs, dict):
|
if not isinstance(processor_kwargs, dict):
|
||||||
@@ -512,74 +448,60 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
|||||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||||
|
|
||||||
|
|
||||||
def _build_relative_action_mask(
|
def _legacy_groot_processor_overrides(
|
||||||
action_dim: int,
|
config: GrootConfig,
|
||||||
exclude_joints: list[str] | None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
|
||||||
action_names: list[str] | None,
|
preprocessor_overrides: dict[str, Any] | None = None,
|
||||||
) -> list[bool]:
|
postprocessor_overrides: dict[str, Any] | None = None,
|
||||||
"""Build the per-dim relative-action mask (True = convert to relative, False = keep absolute).
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||||
|
"""Patch older serialized Groot processors with fields current processors expect."""
|
||||||
|
|
||||||
Replicates ``RelativeActionsProcessorStep._build_mask`` semantics: dims are excluded
|
preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||||
(kept absolute) by case-insensitive token match against ``action_names``.
|
postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||||
|
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
|
||||||
|
|
||||||
When ``action_names`` is None we cannot identify the gripper, so this returns all-True
|
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
|
||||||
(every dim treated as relative). The user should ensure ``config.action_feature_names`` is
|
pack_input_overrides["normalize_min_max"] = True
|
||||||
populated (the factory does this from dataset meta) so the gripper can be kept absolute;
|
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
|
||||||
arm-relative still works either way, but a missing-name gripper would be treated as relative.
|
|
||||||
|
try:
|
||||||
|
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||||
|
except Exception:
|
||||||
|
env_action_dim = 0
|
||||||
|
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v2", {}))
|
||||||
|
action_unpack_overrides["normalize_min_max"] = True
|
||||||
|
action_unpack_overrides["env_action_dim"] = env_action_dim
|
||||||
|
postprocessor_overrides["groot_action_unpack_unnormalize_v2"] = action_unpack_overrides
|
||||||
|
|
||||||
|
return preprocessor_overrides, postprocessor_overrides
|
||||||
|
|
||||||
|
|
||||||
|
def _pretrained_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
||||||
|
"""Check whether a serialized processor pipeline contains a registry step.
|
||||||
|
|
||||||
|
Resolves the processor config from a local directory or, for Hub repo ids,
|
||||||
|
via ``hf_hub_download`` (which serves the cached copy when offline). Returns
|
||||||
|
False when the config cannot be resolved; loading then proceeds with the
|
||||||
|
legacy overrides and `make_groot_pre_post_processors_from_pretrained` retries
|
||||||
|
without them if they do not match the serialized pipeline.
|
||||||
"""
|
"""
|
||||||
if not exclude_joints or action_names is None:
|
path = Path(pretrained_path).expanduser()
|
||||||
return [True] * action_dim
|
if path.is_dir():
|
||||||
|
config = _read_json(path / config_filename)
|
||||||
exclude_tokens = [str(name).lower() for name in exclude_joints if name]
|
elif path.exists():
|
||||||
if not exclude_tokens:
|
return False
|
||||||
return [True] * action_dim
|
else:
|
||||||
|
try:
|
||||||
mask: list[bool] = []
|
config_path = hf_hub_download(
|
||||||
for name in action_names[:action_dim]:
|
repo_id=str(pretrained_path), filename=config_filename, repo_type="model"
|
||||||
action_name = str(name).lower()
|
|
||||||
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
|
||||||
mask.append(not is_excluded)
|
|
||||||
|
|
||||||
if len(mask) < action_dim:
|
|
||||||
mask.extend([True] * (action_dim - len(mask)))
|
|
||||||
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
|
|
||||||
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
|
|
||||||
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
|
|
||||||
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
|
|
||||||
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
|
|
||||||
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
|
|
||||||
|
|
||||||
|
|
||||||
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
|
|
||||||
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
|
|
||||||
|
|
||||||
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
|
|
||||||
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
|
|
||||||
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
|
|
||||||
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
|
|
||||||
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
|
|
||||||
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
|
|
||||||
(e.g. a typo) is left in place and still raises.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not overrides:
|
|
||||||
return overrides
|
|
||||||
|
|
||||||
filtered: dict[str, Any] = {}
|
|
||||||
for key, value in overrides.items():
|
|
||||||
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
|
|
||||||
logging.debug(
|
|
||||||
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
|
|
||||||
"no matching step (see GrootConfig.normalization_mapping).",
|
|
||||||
key,
|
|
||||||
)
|
)
|
||||||
continue
|
except Exception:
|
||||||
filtered[key] = value
|
return False
|
||||||
return filtered
|
config = _read_json(Path(config_path))
|
||||||
|
steps = config.get("steps", [])
|
||||||
|
if not isinstance(steps, list):
|
||||||
|
return False
|
||||||
|
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
|
||||||
|
|
||||||
|
|
||||||
def _apply_groot_step_overrides(
|
def _apply_groot_step_overrides(
|
||||||
@@ -595,8 +517,7 @@ def _apply_groot_step_overrides(
|
|||||||
steps by registry name only — prefer registry names so overrides keep
|
steps by registry name only — prefer registry names so overrides keep
|
||||||
working after the checkpoint is converted and reloaded from a serialized
|
working after the checkpoint is converted and reloaded from a serialized
|
||||||
pipeline). Keys or fields that match nothing raise instead of being dropped
|
pipeline). Keys or fields that match nothing raise instead of being dropped
|
||||||
silently (standard normalization keys GR00T has no step for are removed
|
silently.
|
||||||
beforehand by ``_drop_groot_absent_standard_overrides``).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not overrides:
|
if not overrides:
|
||||||
@@ -652,13 +573,7 @@ def make_groot_pre_post_processors_from_pretrained(
|
|||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
]:
|
]:
|
||||||
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
|
"""Load Groot processors while preserving compatibility with older serialized configs."""
|
||||||
|
|
||||||
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
|
|
||||||
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
|
|
||||||
# paths raise. This must happen before either branch below.
|
|
||||||
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
|
|
||||||
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
|
|
||||||
|
|
||||||
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
||||||
processor_cfg = copy(config)
|
processor_cfg = copy(config)
|
||||||
@@ -674,13 +589,49 @@ def make_groot_pre_post_processors_from_pretrained(
|
|||||||
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
|
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
|
||||||
return preprocessor, postprocessor
|
return preprocessor, postprocessor
|
||||||
|
|
||||||
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
caller_preprocessor_overrides = dict(preprocessor_overrides or {})
|
||||||
|
caller_postprocessor_overrides = dict(postprocessor_overrides or {})
|
||||||
|
if _pretrained_processor_config_has_step(
|
||||||
pretrained_path,
|
pretrained_path,
|
||||||
preprocessor_overrides=preprocessor_overrides,
|
postprocessor_config_filename,
|
||||||
postprocessor_overrides=postprocessor_overrides,
|
"groot_n1_7_action_decode_v1",
|
||||||
preprocessor_config_filename=preprocessor_config_filename,
|
):
|
||||||
postprocessor_config_filename=postprocessor_config_filename,
|
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
|
||||||
)
|
# action decoder. Adding the legacy action-unpack override would target
|
||||||
|
# a step that is not present and break loading.
|
||||||
|
applied_legacy_overrides = False
|
||||||
|
preprocessor_overrides = caller_preprocessor_overrides
|
||||||
|
postprocessor_overrides = caller_postprocessor_overrides
|
||||||
|
else:
|
||||||
|
applied_legacy_overrides = True
|
||||||
|
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
||||||
|
config=config,
|
||||||
|
dataset_stats=dataset_stats,
|
||||||
|
preprocessor_overrides=preprocessor_overrides,
|
||||||
|
postprocessor_overrides=postprocessor_overrides,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||||
|
pretrained_path,
|
||||||
|
preprocessor_overrides=preprocessor_overrides,
|
||||||
|
postprocessor_overrides=postprocessor_overrides,
|
||||||
|
preprocessor_config_filename=preprocessor_config_filename,
|
||||||
|
postprocessor_config_filename=postprocessor_config_filename,
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
if not applied_legacy_overrides:
|
||||||
|
raise
|
||||||
|
# The legacy overrides target steps that are absent from the serialized
|
||||||
|
# pipelines (e.g. a converted raw N1.7 checkpoint whose postprocessor
|
||||||
|
# config could not be inspected before loading); retry with the caller
|
||||||
|
# overrides only.
|
||||||
|
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||||
|
pretrained_path,
|
||||||
|
preprocessor_overrides=caller_preprocessor_overrides,
|
||||||
|
postprocessor_overrides=caller_postprocessor_overrides,
|
||||||
|
preprocessor_config_filename=preprocessor_config_filename,
|
||||||
|
postprocessor_config_filename=postprocessor_config_filename,
|
||||||
|
)
|
||||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||||
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
||||||
return preprocessor, postprocessor
|
return preprocessor, postprocessor
|
||||||
@@ -747,15 +698,8 @@ def _reconnect_groot_n1_7_pack_decode_steps(
|
|||||||
if pack_step is None:
|
if pack_step is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Both decode steps read the pack step's cached state via a non-serialized ``pack_step`` link:
|
|
||||||
# GrootN17ActionDecodeStep reads the per-modality raw state; the relative-action path
|
|
||||||
# (GrootActionUnpackUnnormalizeStep) reads the cached reference state. Restore both links after
|
|
||||||
# deserialization.
|
|
||||||
for step in postprocessor.steps:
|
for step in postprocessor.steps:
|
||||||
if (
|
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
|
||||||
isinstance(step, (GrootN17ActionDecodeStep, GrootActionUnpackUnnormalizeStep))
|
|
||||||
and step.pack_step is None
|
|
||||||
):
|
|
||||||
step.pack_step = pack_step
|
step.pack_step = pack_step
|
||||||
|
|
||||||
|
|
||||||
@@ -833,45 +777,23 @@ def make_groot_pre_post_processors(
|
|||||||
video_modality_keys=video_modality_keys,
|
video_modality_keys=video_modality_keys,
|
||||||
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
|
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
|
||||||
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
||||||
use_relative_actions=config.use_relative_actions,
|
|
||||||
relative_exclude_joints=config.relative_exclude_joints,
|
|
||||||
action_feature_names=config.action_feature_names,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
|
|
||||||
# when it provides an image_target_size; otherwise fall back to the geometry the
|
|
||||||
# N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no
|
|
||||||
# processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new
|
|
||||||
# embodiment, where checkpoint_assets is None) would patchify full-resolution camera
|
|
||||||
# frames, inflating the VLM token count -- slowing both dataloading_s and update_s --
|
|
||||||
# and feeding the model a resolution it was not trained on.
|
|
||||||
if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None:
|
|
||||||
image_target_size = checkpoint_assets.image_target_size
|
|
||||||
image_crop_size = checkpoint_assets.image_crop_size
|
|
||||||
shortest_image_edge = checkpoint_assets.shortest_image_edge
|
|
||||||
crop_fraction = checkpoint_assets.crop_fraction
|
|
||||||
else:
|
|
||||||
image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE)
|
|
||||||
image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE)
|
|
||||||
shortest_image_edge = None
|
|
||||||
crop_fraction = None
|
|
||||||
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
|
|
||||||
|
|
||||||
input_steps: list[ProcessorStep] = [
|
input_steps: list[ProcessorStep] = [
|
||||||
RenameObservationsProcessorStep(rename_map={}),
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
AddBatchDimensionProcessorStep(),
|
AddBatchDimensionProcessorStep(),
|
||||||
pack_step,
|
pack_step,
|
||||||
GrootN17VLMEncodeStep(
|
GrootN17VLMEncodeStep(
|
||||||
model_name=config.n1_7_backbone_model,
|
model_name=config.n1_7_backbone_model,
|
||||||
image_crop_size=image_crop_size,
|
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
|
||||||
image_target_size=image_target_size,
|
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
|
||||||
shortest_image_edge=shortest_image_edge,
|
shortest_image_edge=checkpoint_assets.shortest_image_edge
|
||||||
crop_fraction=crop_fraction,
|
if checkpoint_assets is not None
|
||||||
use_albumentations=use_albumentations,
|
else None,
|
||||||
# Run the image resize/normalize/patchify on the training device when
|
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
|
||||||
# possible instead of the single CPU main-loop thread (the dominant
|
use_albumentations=checkpoint_assets.use_albumentations
|
||||||
# cost folded into dataloading_s).
|
if checkpoint_assets is not None
|
||||||
device=config.device,
|
else False,
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
]
|
]
|
||||||
@@ -895,10 +817,6 @@ def make_groot_pre_post_processors(
|
|||||||
stats=padded_stats,
|
stats=padded_stats,
|
||||||
normalize_min_max=True,
|
normalize_min_max=True,
|
||||||
clip_normalized_action=True,
|
clip_normalized_action=True,
|
||||||
use_relative_actions=config.use_relative_actions,
|
|
||||||
relative_exclude_joints=config.relative_exclude_joints,
|
|
||||||
action_feature_names=config.action_feature_names,
|
|
||||||
pack_step=pack_step,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
action_decode_step = GrootN17ActionDecodeStep(
|
action_decode_step = GrootN17ActionDecodeStep(
|
||||||
@@ -1114,61 +1032,6 @@ def _transform_n1_7_image_for_vlm(
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def _transform_n1_7_image_for_vlm_torch(
|
|
||||||
image: torch.Tensor,
|
|
||||||
*,
|
|
||||||
image_crop_size: list[int] | None,
|
|
||||||
image_target_size: list[int] | None,
|
|
||||||
shortest_image_edge: int | None,
|
|
||||||
crop_fraction: float | None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Torch/torchvision port of the non-albumentations branch of
|
|
||||||
:func:`_transform_n1_7_image_for_vlm`.
|
|
||||||
|
|
||||||
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
|
||||||
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
|
||||||
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
|
|
||||||
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
|
|
||||||
cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper.
|
|
||||||
"""
|
|
||||||
if image_target_size is None:
|
|
||||||
return image
|
|
||||||
|
|
||||||
target_h, target_w = image_target_size
|
|
||||||
_, height, width = image.shape
|
|
||||||
|
|
||||||
square_edge = max(height, width)
|
|
||||||
if height != width:
|
|
||||||
left = (square_edge - width) // 2
|
|
||||||
top = (square_edge - height) // 2
|
|
||||||
image = tv_functional.pad(
|
|
||||||
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
|
||||||
)
|
|
||||||
|
|
||||||
resize_edge = shortest_image_edge or target_h
|
|
||||||
image = tv_functional.resize(
|
|
||||||
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if crop_fraction is None and image_crop_size is not None:
|
|
||||||
crop_fraction = image_crop_size[0] / float(target_h)
|
|
||||||
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
|
||||||
# Match the PIL helper's center crop exactly: round() the crop size but
|
|
||||||
# floor() the offset (torchvision.center_crop rounds the offset, which
|
|
||||||
# shifts the region by 1px when (edge - crop) is odd).
|
|
||||||
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
|
|
||||||
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
|
|
||||||
top = max(0, (image.shape[-2] - crop_h) // 2)
|
|
||||||
left = max(0, (image.shape[-1] - crop_w) // 2)
|
|
||||||
image = image[..., top : top + crop_h, left : left + crop_w]
|
|
||||||
|
|
||||||
if tuple(image.shape[-2:]) != (target_h, target_w):
|
|
||||||
image = tv_functional.resize(
|
|
||||||
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
|
|
||||||
)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
|
@ProcessorStepRegistry.register(name="groot_n1_7_pack_inputs_v1")
|
||||||
class GrootN17PackInputsStep(ProcessorStep):
|
class GrootN17PackInputsStep(ProcessorStep):
|
||||||
@@ -1195,18 +1058,11 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
video_modality_keys: list[str] | None = None
|
video_modality_keys: list[str] | None = None
|
||||||
raw_stats: dict[str, Any] | None = None
|
raw_stats: dict[str, Any] | None = None
|
||||||
modality_config: dict[str, Any] | None = None
|
modality_config: dict[str, Any] | None = None
|
||||||
# Opt-in relative-action support: convert absolute->relative actions inside this pack step
|
# Unused: kept so serialized configs that include it still load. The raw
|
||||||
# (training) using the cached raw reference state, keeping excluded joints (e.g. gripper)
|
# state cache is per instance (_last_raw_state), never process-global.
|
||||||
# absolute. The paired GrootActionUnpackUnnormalizeStep reconstructs absolute on decode.
|
state_cache_key: str = ""
|
||||||
use_relative_actions: bool = False
|
|
||||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
|
||||||
action_feature_names: list[str] | None = None
|
|
||||||
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
||||||
_last_reference_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
|
||||||
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
||||||
_ref_holder: "_GrootRelativeRefHolder" = field(
|
|
||||||
default_factory=_GrootRelativeRefHolder, init=False, repr=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
|
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
|
||||||
available = {key for key in obs if key.startswith(OBS_IMAGES)}
|
available = {key for key in obs if key.startswith(OBS_IMAGES)}
|
||||||
@@ -1328,7 +1184,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
start_idx += dim
|
start_idx += dim
|
||||||
if grouped:
|
if grouped:
|
||||||
self._last_raw_state = grouped
|
self._last_raw_state = grouped
|
||||||
self._ref_holder.raw_state = grouped
|
|
||||||
|
|
||||||
img_keys = self._ordered_image_keys(obs)
|
img_keys = self._ordered_image_keys(obs)
|
||||||
if img_keys:
|
if img_keys:
|
||||||
@@ -1348,9 +1203,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
formalize_language=self.formalize_language,
|
formalize_language=self.formalize_language,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reference state for relative-action conversion (RAW, pre-normalization, (B, D)). Cached
|
|
||||||
# regardless of whether an action is present so inference caches it too for decode.
|
|
||||||
relative_reference_state: torch.Tensor | None = None
|
|
||||||
if OBS_STATE in obs:
|
if OBS_STATE in obs:
|
||||||
state = obs[OBS_STATE]
|
state = obs[OBS_STATE]
|
||||||
if state.dim() != 2:
|
if state.dim() != 2:
|
||||||
@@ -1359,10 +1211,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
if dim > self.max_state_dim:
|
if dim > self.max_state_dim:
|
||||||
raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.")
|
raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.")
|
||||||
_cache_raw_state(state)
|
_cache_raw_state(state)
|
||||||
if self.use_relative_actions:
|
|
||||||
relative_reference_state = state.detach().clone()
|
|
||||||
self._last_reference_state = relative_reference_state
|
|
||||||
self._ref_holder.reference_state = relative_reference_state
|
|
||||||
if self.normalize_min_max:
|
if self.normalize_min_max:
|
||||||
state = _min_max_norm(state, OBS_STATE)
|
state = _min_max_norm(state, OBS_STATE)
|
||||||
state = state.unsqueeze(1)
|
state = state.unsqueeze(1)
|
||||||
@@ -1385,19 +1233,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.")
|
raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.")
|
||||||
if dim > self.max_action_dim:
|
if dim > self.max_action_dim:
|
||||||
raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.")
|
raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.")
|
||||||
# Convert absolute->relative BEFORE normalization. The mask keeps excluded joints (e.g.
|
|
||||||
# gripper) absolute; to_relative_actions broadcasts the (B, D) reference state over T.
|
|
||||||
if self.use_relative_actions:
|
|
||||||
if relative_reference_state is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"GrootN17PackInputsStep.use_relative_actions requires observation.state "
|
|
||||||
"(OBS_STATE) to be present alongside the action to build the relative "
|
|
||||||
"reference, but no state was found in this transition."
|
|
||||||
)
|
|
||||||
mask = _build_relative_action_mask(
|
|
||||||
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
|
|
||||||
)
|
|
||||||
action = to_relative_actions(action, relative_reference_state, mask)
|
|
||||||
if self.normalize_min_max:
|
if self.normalize_min_max:
|
||||||
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
|
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
|
||||||
action = flat.view(bsz, horizon, dim)
|
action = flat.view(bsz, horizon, dim)
|
||||||
@@ -1437,12 +1272,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
comp["action_mask"] = action_mask
|
comp["action_mask"] = action_mask
|
||||||
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.int32, device=device)
|
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
# Publish the runtime-only reference holder so the policy can freeze it at the predict
|
|
||||||
# event and the decode step can read the frozen reference. It rides in COMPLEMENTARY_DATA,
|
|
||||||
# survives the VLM-encode step and DeviceProcessorStep as a non-tensor, and reaches the
|
|
||||||
# policy via the batch (by object identity) through the pipeline's shallow copies.
|
|
||||||
comp[_GROOT_REF_HOLDER_KEY] = self._ref_holder
|
|
||||||
|
|
||||||
transition[TransitionKey.OBSERVATION] = obs
|
transition[TransitionKey.OBSERVATION] = obs
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||||
return transition
|
return transition
|
||||||
@@ -1467,9 +1296,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
"video_modality_keys": self.video_modality_keys,
|
"video_modality_keys": self.video_modality_keys,
|
||||||
"raw_stats": self.raw_stats,
|
"raw_stats": self.raw_stats,
|
||||||
"modality_config": self.modality_config,
|
"modality_config": self.modality_config,
|
||||||
"use_relative_actions": self.use_relative_actions,
|
|
||||||
"relative_exclude_joints": self.relative_exclude_joints,
|
|
||||||
"action_feature_names": self.action_feature_names,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
|
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
|
||||||
@@ -1477,23 +1303,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
|
|
||||||
return self._last_raw_state
|
return self._last_raw_state
|
||||||
|
|
||||||
def get_cached_reference_state(self) -> torch.Tensor | None:
|
|
||||||
"""Return the latest RAW (pre-normalization) (B, D) state used for relative-action conversion."""
|
|
||||||
|
|
||||||
return self._last_reference_state
|
|
||||||
|
|
||||||
def get_reference_holder(self) -> "_GrootRelativeRefHolder":
|
|
||||||
"""Return the runtime-only holder shared with the policy (writer) and decode step (reader)."""
|
|
||||||
|
|
||||||
return self._ref_holder
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Clear cached per-episode relative-action references (sync engine resets on episode boundaries)."""
|
|
||||||
|
|
||||||
self._last_reference_state = None
|
|
||||||
self._last_raw_state = None
|
|
||||||
self._ref_holder.clear()
|
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
if not self.stats:
|
if not self.stats:
|
||||||
return {}
|
return {}
|
||||||
@@ -1524,12 +1333,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
|
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
|
||||||
an image item in the same chat message so the resulting image tokens match
|
an image item in the same chat message so the resulting image tokens match
|
||||||
the temporal VLM packing used by Isaac-GR00T.
|
the temporal VLM packing used by Isaac-GR00T.
|
||||||
|
|
||||||
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
|
|
||||||
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
|
|
||||||
CUDA device, the resize/rescale/normalize/patchify run there instead of on the
|
|
||||||
single CPU main-loop thread. This keeps the output bit-identical on CPU and
|
|
||||||
moves the dominant preprocessing cost off the critical path on GPU.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
||||||
@@ -1538,7 +1341,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
shortest_image_edge: int | None = None
|
shortest_image_edge: int | None = None
|
||||||
crop_fraction: float | None = None
|
crop_fraction: float | None = None
|
||||||
use_albumentations: bool = False
|
use_albumentations: bool = False
|
||||||
device: str | None = None
|
|
||||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1547,70 +1349,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
self._proc = _build_n1_7_processor(self.model_name)
|
self._proc = _build_n1_7_processor(self.model_name)
|
||||||
return self._proc
|
return self._proc
|
||||||
|
|
||||||
def _target_device(self) -> torch.device | None:
|
|
||||||
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
|
|
||||||
if self.device is None or self.use_albumentations:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return get_safe_torch_device(self.device)
|
|
||||||
except (AssertionError, RuntimeError):
|
|
||||||
# A device serialized at train time (e.g. "cuda") may be unavailable
|
|
||||||
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
|
|
||||||
# this step is not in the standard device-override set. Fall back to
|
|
||||||
# the CPU path, which is bit-identical, instead of crashing.
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _build_sample_images(
|
|
||||||
self, video: Any, batch_size: int, target_device: torch.device | None
|
|
||||||
) -> list[list[Any]]:
|
|
||||||
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
|
|
||||||
|
|
||||||
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
|
|
||||||
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
|
|
||||||
``target_device`` when set) for the torchvision-backed Qwen processor.
|
|
||||||
"""
|
|
||||||
if self.use_albumentations:
|
|
||||||
video_np = np.asarray(video)
|
|
||||||
return [
|
|
||||||
[
|
|
||||||
_transform_n1_7_image_for_vlm(
|
|
||||||
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
|
|
||||||
image_crop_size=self.image_crop_size,
|
|
||||||
image_target_size=self.image_target_size,
|
|
||||||
shortest_image_edge=self.shortest_image_edge,
|
|
||||||
crop_fraction=self.crop_fraction,
|
|
||||||
use_albumentations=True,
|
|
||||||
)
|
|
||||||
for timestep in range(video_np.shape[1])
|
|
||||||
for view_idx in range(video_np.shape[2])
|
|
||||||
]
|
|
||||||
for batch_idx in range(batch_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
|
|
||||||
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
|
|
||||||
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
|
|
||||||
if target_device is not None and video_t.device != target_device:
|
|
||||||
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
|
|
||||||
|
|
||||||
frames_per_sample: list[list[Any]] = []
|
|
||||||
for batch_idx in range(batch_size):
|
|
||||||
sample = video_t[batch_idx] # (T, V, C, H, W)
|
|
||||||
frames_per_sample.append(
|
|
||||||
[
|
|
||||||
_transform_n1_7_image_for_vlm_torch(
|
|
||||||
sample[timestep, view_idx],
|
|
||||||
image_crop_size=self.image_crop_size,
|
|
||||||
image_target_size=self.image_target_size,
|
|
||||||
shortest_image_edge=self.shortest_image_edge,
|
|
||||||
crop_fraction=self.crop_fraction,
|
|
||||||
)
|
|
||||||
for timestep in range(sample.shape[0])
|
|
||||||
for view_idx in range(sample.shape[1])
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return frames_per_sample
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||||
@@ -1618,25 +1356,33 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
if video is None:
|
if video is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
batch_size = int(video.shape[0])
|
|
||||||
languages = _prepare_n1_7_language_batch(
|
languages = _prepare_n1_7_language_batch(
|
||||||
comp.get("language"),
|
comp.get("language"),
|
||||||
batch_size,
|
video.shape[0],
|
||||||
formalize_language=False,
|
formalize_language=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
target_device = self._target_device()
|
|
||||||
sample_images = self._build_sample_images(video, batch_size, target_device)
|
|
||||||
|
|
||||||
texts: list[str] = []
|
texts: list[str] = []
|
||||||
images: list[Any] = []
|
images: list[Image.Image] = []
|
||||||
for batch_idx in range(batch_size):
|
for batch_idx in range(video.shape[0]):
|
||||||
frames = sample_images[batch_idx]
|
sample = video[batch_idx] # (T, V, H, W, C)
|
||||||
|
sample_images = [
|
||||||
|
_transform_n1_7_image_for_vlm(
|
||||||
|
Image.fromarray(sample[timestep, view_idx]),
|
||||||
|
image_crop_size=self.image_crop_size,
|
||||||
|
image_target_size=self.image_target_size,
|
||||||
|
shortest_image_edge=self.shortest_image_edge,
|
||||||
|
crop_fraction=self.crop_fraction,
|
||||||
|
use_albumentations=self.use_albumentations,
|
||||||
|
)
|
||||||
|
for timestep in range(sample.shape[0])
|
||||||
|
for view_idx in range(sample.shape[1])
|
||||||
|
]
|
||||||
conversation = [
|
conversation = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
*[{"type": "image", "image": image} for image in frames],
|
*[{"type": "image", "image": image} for image in sample_images],
|
||||||
{"type": "text", "text": languages[batch_idx]},
|
{"type": "text", "text": languages[batch_idx]},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -1648,17 +1394,9 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
images.extend(frames)
|
images.extend(sample_images)
|
||||||
|
|
||||||
proc_kwargs: dict[str, Any] = {
|
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
"text": texts,
|
|
||||||
"images": images,
|
|
||||||
"return_tensors": "pt",
|
|
||||||
"padding": True,
|
|
||||||
}
|
|
||||||
if target_device is not None:
|
|
||||||
proc_kwargs["device"] = str(target_device)
|
|
||||||
encoded = self.proc(**proc_kwargs)
|
|
||||||
for key, value in encoded.items():
|
for key, value in encoded.items():
|
||||||
comp[key] = value
|
comp[key] = value
|
||||||
obs.pop("video", None)
|
obs.pop("video", None)
|
||||||
@@ -1677,7 +1415,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
"shortest_image_edge": self.shortest_image_edge,
|
"shortest_image_edge": self.shortest_image_edge,
|
||||||
"crop_fraction": self.crop_fraction,
|
"crop_fraction": self.crop_fraction,
|
||||||
"use_albumentations": self.use_albumentations,
|
"use_albumentations": self.use_albumentations,
|
||||||
"device": self.device,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -1828,6 +1565,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
modality_config: dict[str, Any] | None = None
|
modality_config: dict[str, Any] | None = None
|
||||||
use_percentiles: bool = False
|
use_percentiles: bool = False
|
||||||
use_relative_action: bool = False
|
use_relative_action: bool = False
|
||||||
|
# Unused: kept so serialized configs that include it still load.
|
||||||
|
state_cache_key: str = ""
|
||||||
action_decode_transform: str | None = None
|
action_decode_transform: str | None = None
|
||||||
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
||||||
|
|
||||||
@@ -1890,14 +1629,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
start_idx += dim
|
start_idx += dim
|
||||||
|
|
||||||
if self.use_relative_action:
|
if self.use_relative_action:
|
||||||
# Prefer the raw state frozen at the chunk-prediction event (see the relative-action
|
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
|
||||||
# branch of GrootActionUnpackUnnormalizeStep). Falls back to the live cached raw state.
|
|
||||||
holder = self.pack_step.get_reference_holder() if self.pack_step is not None else None
|
|
||||||
raw_state = None
|
|
||||||
if holder is not None:
|
|
||||||
raw_state = holder.frozen_raw if holder.frozen_raw is not None else holder.raw_state
|
|
||||||
if raw_state is None and self.pack_step is not None:
|
|
||||||
raw_state = self.pack_step.get_cached_raw_state()
|
|
||||||
if raw_state is None:
|
if raw_state is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
|
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
|
||||||
@@ -1962,10 +1694,10 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
|
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
|
||||||
# action chunks to the last timestep, so old serialized v1 pipelines must not
|
# action chunks to the last timestep, so old serialized v1 pipelines must not
|
||||||
# silently load into it (v1 is stubbed below with the removal guidance).
|
# silently load into it (v1 is stubbed below with the removal guidance).
|
||||||
@dataclass
|
|
||||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
|
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
|
||||||
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||||
env_action_dim: int = 0
|
env_action_dim: int = 0
|
||||||
@@ -1975,13 +1707,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
|||||||
clip_normalized_action: bool = False
|
clip_normalized_action: bool = False
|
||||||
libero_gripper_action: bool = False
|
libero_gripper_action: bool = False
|
||||||
libero_gripper_binarize: bool = True
|
libero_gripper_binarize: bool = True
|
||||||
# Opt-in relative-action reconstruction (paired with GrootN17PackInputsStep). After the
|
|
||||||
# min-max inverse, relative deltas (arm) + absolute gripper are converted back to absolute
|
|
||||||
# using the reference state cached by the linked pack_step (re-linked on reload).
|
|
||||||
use_relative_actions: bool = False
|
|
||||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
|
||||||
action_feature_names: list[str] | None = None
|
|
||||||
pack_step: "GrootN17PackInputsStep | None" = field(default=None, repr=False)
|
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
|
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
|
||||||
@@ -2021,35 +1746,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
|||||||
inv = (action + 1.0) * 0.5 * safe_denom + min_v
|
inv = (action + 1.0) * 0.5 * safe_denom + min_v
|
||||||
action = torch.where(mask, inv, min_v)
|
action = torch.where(mask, inv, min_v)
|
||||||
|
|
||||||
# Reconstruct absolute actions from relative deltas (arm) + absolute gripper, using the
|
|
||||||
# reference state cached by the linked pack step. The link is restored on reload by
|
|
||||||
# _reconnect_groot_n1_7_pack_decode_steps.
|
|
||||||
if self.use_relative_actions:
|
|
||||||
if self.pack_step is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires a linked "
|
|
||||||
"GrootN17PackInputsStep to read the cached reference state, but pack_step is None. "
|
|
||||||
"Build both pipelines through make_groot_pre_post_processors (or load them together "
|
|
||||||
"via make_groot_pre_post_processors_from_pretrained)."
|
|
||||||
)
|
|
||||||
# Prefer the reference frozen at the chunk-prediction event (set by
|
|
||||||
# GrootPolicy.predict_action_chunk via the shared holder) so every popped delta of a
|
|
||||||
# chunk reconstructs against that chunk's start state S_T, not the per-tick latest
|
|
||||||
# state. Falls back to the live reference when nothing was frozen (e.g. decode without
|
|
||||||
# a preceding predict event, or RTC/async where frozen == live).
|
|
||||||
holder = self.pack_step.get_reference_holder()
|
|
||||||
ref = holder.frozen_reference if holder.frozen_reference is not None else holder.reference_state
|
|
||||||
if ref is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires the reference state "
|
|
||||||
"cached by its connected GrootN17PackInputsStep to convert relative actions back to "
|
|
||||||
"absolute. Run the preprocessor on an observation before decoding actions."
|
|
||||||
)
|
|
||||||
relative_mask = _build_relative_action_mask(
|
|
||||||
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
|
|
||||||
)
|
|
||||||
action = to_absolute_actions(action, ref, relative_mask)
|
|
||||||
|
|
||||||
if self.libero_gripper_action and action.shape[-1] >= 7:
|
if self.libero_gripper_action and action.shape[-1] >= 7:
|
||||||
gripper = action[..., -1]
|
gripper = action[..., -1]
|
||||||
if self.libero_gripper_binarize:
|
if self.libero_gripper_binarize:
|
||||||
@@ -2077,9 +1773,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
|||||||
"clip_normalized_action": self.clip_normalized_action,
|
"clip_normalized_action": self.clip_normalized_action,
|
||||||
"libero_gripper_action": self.libero_gripper_action,
|
"libero_gripper_action": self.libero_gripper_action,
|
||||||
"libero_gripper_binarize": self.libero_gripper_binarize,
|
"libero_gripper_binarize": self.libero_gripper_binarize,
|
||||||
"use_relative_actions": self.use_relative_actions,
|
|
||||||
"relative_exclude_joints": self.relative_exclude_joints,
|
|
||||||
"action_feature_names": self.action_feature_names,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||||
|
|||||||
@@ -207,6 +207,11 @@ def test_lerobot_groot_forward_pass():
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
|
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
|
||||||
|
|
||||||
|
assert isinstance(lerobot_loss, torch.Tensor)
|
||||||
|
assert torch.isfinite(lerobot_loss).all()
|
||||||
|
assert "loss" in lerobot_metrics
|
||||||
|
assert np.isfinite(lerobot_metrics["loss"])
|
||||||
|
|
||||||
print("\nForward pass successful.")
|
print("\nForward pass successful.")
|
||||||
print(f" - Loss: {lerobot_loss.item():.6f}")
|
print(f" - Loss: {lerobot_loss.item():.6f}")
|
||||||
print(f" - Metrics: {lerobot_metrics}")
|
print(f" - Metrics: {lerobot_metrics}")
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -14,31 +14,36 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||||
|
|
||||||
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
|
Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
|
||||||
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
|
once in the original ``gr00t`` env by the companion script
|
||||||
normalized flow-matching prediction before any action decoding) as NVIDIA's original
|
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):
|
||||||
``gr00t`` package, given byte-identical pre-processed inputs and the same
|
|
||||||
flow-matching seed. The comparison is parametrized over every embodiment tag present
|
|
||||||
in the checkpoint.
|
|
||||||
|
|
||||||
To keep the comparison fair, the original outputs + the exact collated inputs are
|
1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
|
||||||
produced once per embodiment in the original ``gr00t`` env via the companion script
|
action head + Qwen3-VL backbone must produce the SAME raw model output
|
||||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
|
(``action_pred``, the normalized flow-matching prediction before any action
|
||||||
to per-tag ``.npz`` files.
|
decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
|
||||||
This test discovers those artifacts, replays the identical inputs through the LeRobot
|
pre-processed inputs and the flow-matching seed recorded in the artifact.
|
||||||
model, and compares.
|
2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
|
||||||
|
template / tokenizer / image packing + state normalization, no mocks) must produce
|
||||||
|
the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
|
||||||
|
as the original package's processor, given the identical raw observations
|
||||||
|
(images, state, language) recorded in the artifact. Artifacts written by older
|
||||||
|
versions of the dump script carry no raw observations; this case then SKIPS with
|
||||||
|
a regeneration hint.
|
||||||
|
|
||||||
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
|
These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
|
||||||
present, or when no artifact has been generated. By default it looks for artifacts in
|
present, or when no artifact has been generated. By default they look for artifacts in
|
||||||
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
||||||
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
||||||
for the full run procedure.
|
for the full run procedure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -50,7 +55,9 @@ pytestmark = pytest.mark.skipif(
|
|||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||||
|
|
||||||
|
# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
|
||||||
SEED = 42
|
SEED = 42
|
||||||
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
||||||
@@ -60,6 +67,11 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
|
|||||||
_ARTIFACT_PREFIX = "original_n1_7_"
|
_ARTIFACT_PREFIX = "original_n1_7_"
|
||||||
_ARTIFACT_SUFFIX = ".npz"
|
_ARTIFACT_SUFFIX = ".npz"
|
||||||
|
|
||||||
|
# Collated keys compared by the preprocessor parity case: integer/id tensors must
|
||||||
|
# match exactly; float tensors within ATOL/RTOL.
|
||||||
|
_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
|
||||||
|
_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
|
||||||
|
|
||||||
|
|
||||||
def _artifact_dir() -> Path:
|
def _artifact_dir() -> Path:
|
||||||
"""Directory holding the per-embodiment .npz artifacts.
|
"""Directory holding the per-embodiment .npz artifacts.
|
||||||
@@ -109,9 +121,20 @@ def _resolve_checkpoint() -> str:
|
|||||||
return str(ckpt)
|
return str(ckpt)
|
||||||
|
|
||||||
|
|
||||||
def _load_artifact(path: Path):
|
def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
|
||||||
|
"""Return (original action_pred, collated model inputs, flow-matching seed)."""
|
||||||
data = np.load(path, allow_pickle=True)
|
data = np.load(path, allow_pickle=True)
|
||||||
original_action = torch.from_numpy(data["action_pred"]).float()
|
original_action = torch.from_numpy(data["action_pred"]).float()
|
||||||
|
if "seed" in data.files:
|
||||||
|
seed = int(data["seed"])
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
f"Artifact '{path.name}' does not record the producer seed (it predates the current "
|
||||||
|
f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
|
||||||
|
"regenerate the artifact with the current dump script.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
seed = SEED
|
||||||
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
||||||
inputs = {}
|
inputs = {}
|
||||||
for key in data.files:
|
for key in data.files:
|
||||||
@@ -124,7 +147,45 @@ def _load_artifact(path: Path):
|
|||||||
if "int" in declared or "long" in declared:
|
if "int" in declared or "long" in declared:
|
||||||
t = t.long()
|
t = t.long()
|
||||||
inputs[name] = t
|
inputs[name] = t
|
||||||
return original_action, inputs
|
return original_action, inputs, seed
|
||||||
|
|
||||||
|
|
||||||
|
def _load_raw_observation(path: Path) -> dict[str, Any] | None:
|
||||||
|
"""Return the raw observation recorded in the artifact, or None for old artifacts.
|
||||||
|
|
||||||
|
Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
|
||||||
|
exact raw observation the producer fed to the original processor: per-camera uint8
|
||||||
|
frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
|
||||||
|
(``raw::state.<key>``, (B, T, dim)) and the language instruction
|
||||||
|
(``raw::language``, one string per batch element). ``raw_video_keys`` /
|
||||||
|
``raw_state_keys`` record the checkpoint modality-key order.
|
||||||
|
"""
|
||||||
|
data = np.load(path, allow_pickle=True)
|
||||||
|
markers = ("raw_video_keys", "raw_state_keys", "raw::language")
|
||||||
|
if any(marker not in data.files for marker in markers):
|
||||||
|
return None
|
||||||
|
video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
|
||||||
|
state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
|
||||||
|
return {
|
||||||
|
"video": {k: data[f"raw::video.{k}"] for k in video_keys},
|
||||||
|
"state": {k: data[f"raw::state.{k}"] for k in state_keys},
|
||||||
|
"language": [str(t) for t in data["raw::language"].tolist()],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Convert the producer's raw observation into a LeRobot policy batch."""
|
||||||
|
batch: dict[str, Any] = {}
|
||||||
|
for key, frames in raw["video"].items():
|
||||||
|
# (B, T, H, W, C) uint8 -> (B, T, C, H, W); the pack step converts back losslessly.
|
||||||
|
batch[f"{OBS_IMAGES}.{key}"] = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).contiguous()
|
||||||
|
# observation.state is the per-key state vectors (latest frame) concatenated in
|
||||||
|
# checkpoint modality-key order -- the layout the LeRobot pack step and the
|
||||||
|
# flattened checkpoint statistics expect.
|
||||||
|
state_parts = [torch.from_numpy(np.asarray(arr)[:, -1, :]).float() for arr in raw["state"].values()]
|
||||||
|
batch[OBS_STATE] = torch.cat(state_parts, dim=-1)
|
||||||
|
batch["task"] = list(raw["language"])
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||||
@@ -139,6 +200,36 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
|||||||
return nested.get("inputs", nested)
|
return nested.get("inputs", nested)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_collated_parity(
|
||||||
|
embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
|
||||||
|
) -> None:
|
||||||
|
"""Compare one collated tensor produced by LeRobot against the original's."""
|
||||||
|
assert isinstance(lerobot_value, torch.Tensor), (
|
||||||
|
f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
|
||||||
|
f"{type(lerobot_value).__name__}, expected a tensor."
|
||||||
|
)
|
||||||
|
lerobot_t = lerobot_value.detach().cpu()
|
||||||
|
original_t = original_value.detach().cpu()
|
||||||
|
assert lerobot_t.shape == original_t.shape, (
|
||||||
|
f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs "
|
||||||
|
f"original={tuple(original_t.shape)}."
|
||||||
|
)
|
||||||
|
if exact:
|
||||||
|
mismatched = int((lerobot_t.long() != original_t.long()).sum())
|
||||||
|
assert mismatched == 0, (
|
||||||
|
f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
|
||||||
|
f"{mismatched}/{original_t.numel()} elements mismatch."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lerobot_f, original_f = lerobot_t.float(), original_t.float()
|
||||||
|
max_diff = (lerobot_f - original_f).abs().max().item()
|
||||||
|
print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}")
|
||||||
|
assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), (
|
||||||
|
f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
|
||||||
|
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def lerobot_model():
|
def lerobot_model():
|
||||||
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
||||||
@@ -165,8 +256,7 @@ def lerobot_model():
|
|||||||
|
|
||||||
_ARTIFACTS = _discover_artifacts()
|
_ARTIFACTS = _discover_artifacts()
|
||||||
|
|
||||||
|
_requires_artifacts = pytest.mark.skipif(
|
||||||
@pytest.mark.skipif(
|
|
||||||
not _ARTIFACTS,
|
not _ARTIFACTS,
|
||||||
reason=(
|
reason=(
|
||||||
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
||||||
@@ -174,24 +264,30 @@ _ARTIFACTS = _discover_artifacts()
|
|||||||
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_requires_artifacts
|
||||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||||
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||||
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
||||||
original_action, flat_inputs = _load_artifact(artifact)
|
original_action, flat_inputs, seed = _load_artifact(artifact)
|
||||||
model_inputs = _unflatten(flat_inputs)
|
model_inputs = _unflatten(flat_inputs)
|
||||||
|
|
||||||
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
||||||
torch.manual_seed(SEED)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(SEED)
|
torch.cuda.manual_seed_all(seed)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
out = lerobot_model.get_action(model_inputs)
|
out = lerobot_model.get_action(model_inputs)
|
||||||
lerobot_action = out["action_pred"].float().cpu()
|
lerobot_action = out["action_pred"].float().cpu()
|
||||||
|
|
||||||
t = min(original_action.shape[1], lerobot_action.shape[1])
|
assert lerobot_action.shape == original_action.shape, (
|
||||||
d = min(original_action.shape[2], lerobot_action.shape[2])
|
f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
|
||||||
original_action = original_action[:, :t, :d]
|
f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
|
||||||
lerobot_action = lerobot_action[:, :t, :d]
|
"The same checkpoint and inputs must produce identical shapes; this indicates an "
|
||||||
|
"action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
|
||||||
|
"utils/dump_original_n1_7.py)."
|
||||||
|
)
|
||||||
|
|
||||||
diff = torch.abs(lerobot_action - original_action)
|
diff = torch.abs(lerobot_action - original_action)
|
||||||
max_diff = diff.max().item()
|
max_diff = diff.max().item()
|
||||||
@@ -205,3 +301,56 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
|||||||
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
||||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_requires_artifacts
|
||||||
|
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||||
|
def test_groot_preprocessor_parity(embodiment_tag, artifact):
|
||||||
|
"""LeRobot's real preprocessor vs the original's collated tensors, from identical raw obs.
|
||||||
|
|
||||||
|
Runs LeRobot's full preprocessor pipeline -- including the real Qwen3-VL chat
|
||||||
|
template, tokenizer and image packing plus the checkpoint-driven state
|
||||||
|
normalization (no mocks) -- on the raw observations recorded in the artifact, and
|
||||||
|
compares every collated model input against the ones the original ``gr00t``
|
||||||
|
processor produced from the same raw observations.
|
||||||
|
"""
|
||||||
|
raw = _load_raw_observation(artifact)
|
||||||
|
if raw is None:
|
||||||
|
pytest.skip(
|
||||||
|
f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
|
||||||
|
"not record raw observations; regenerate it with the current dump script to run the "
|
||||||
|
"preprocessor parity case."
|
||||||
|
)
|
||||||
|
_, flat_inputs, _ = _load_artifact(artifact)
|
||||||
|
original_inputs = _unflatten(flat_inputs)
|
||||||
|
|
||||||
|
ckpt = _resolve_checkpoint()
|
||||||
|
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||||
|
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||||
|
|
||||||
|
# CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
|
||||||
|
config = GrootConfig(base_model_path=ckpt, embodiment_tag=embodiment_tag, device="cpu")
|
||||||
|
preprocessor, _ = make_groot_pre_post_processors(config)
|
||||||
|
|
||||||
|
processed = preprocessor(_raw_observation_to_lerobot_batch(raw))
|
||||||
|
|
||||||
|
compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS)
|
||||||
|
missing_original = [k for k in compared_keys if k not in original_inputs]
|
||||||
|
missing_lerobot = [k for k in compared_keys if k not in processed]
|
||||||
|
assert not missing_original, (
|
||||||
|
f"[{embodiment_tag}] artifact collated inputs miss {missing_original} "
|
||||||
|
f"(available: {sorted(original_inputs)}); regenerate the artifact with the current dump script."
|
||||||
|
)
|
||||||
|
assert not missing_lerobot, (
|
||||||
|
f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys "
|
||||||
|
f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})."
|
||||||
|
)
|
||||||
|
|
||||||
|
for name in compared_keys:
|
||||||
|
_assert_collated_parity(
|
||||||
|
embodiment_tag,
|
||||||
|
name,
|
||||||
|
processed[name],
|
||||||
|
original_inputs[name],
|
||||||
|
exact=name in _COLLATED_EXACT_KEYS,
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
|
|||||||
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
||||||
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
||||||
|
|
||||||
|
* the RAW observation fed to the original processor (per-camera uint8 frames,
|
||||||
|
per-key state vectors, the language instruction), so the LeRobot side can also
|
||||||
|
run its OWN preprocessor on identical raw inputs and compare collated tensors,
|
||||||
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
||||||
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
||||||
* the random seed used right before the flow-matching sampler,
|
* the random seed used right before the flow-matching sampler,
|
||||||
@@ -21,8 +24,10 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
|
|||||||
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
||||||
``libero_sim``.
|
``libero_sim``.
|
||||||
|
|
||||||
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
|
The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
|
||||||
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
|
twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
|
||||||
|
(model parity), and the raw observation is replayed through LeRobot's own
|
||||||
|
preprocessor pipeline and compared against the collated inputs (preprocessor parity).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
||||||
@@ -62,10 +67,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
|
|||||||
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
||||||
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
||||||
# present-but-empty so the processor's state transform finds every expected key.
|
# present-but-empty so the processor's state transform finds every expected key.
|
||||||
state = {
|
state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
|
||||||
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
|
|
||||||
for k, dim in state_spec
|
|
||||||
}
|
|
||||||
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
||||||
return {"video": video, "state": state, "language": language}
|
return {"video": video, "state": state, "language": language}
|
||||||
|
|
||||||
@@ -77,6 +79,25 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
|||||||
lang_key = modality_cfg["language"].modality_keys[0]
|
lang_key = modality_cfg["language"].modality_keys[0]
|
||||||
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
||||||
|
|
||||||
|
# Snapshot the RAW observation exactly as fed to the original processor below. The
|
||||||
|
# consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
|
||||||
|
# and compares the resulting collated tensors against the "in::" ones saved further
|
||||||
|
# down. raw_state_keys records the checkpoint modality-key order, which is the
|
||||||
|
# concatenation order of the flat LeRobot ``observation.state`` vector.
|
||||||
|
spec_keys = [key for key, _ in state_spec]
|
||||||
|
state_modality = modality_cfg.get("state")
|
||||||
|
state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
|
||||||
|
state_keys += [key for key in spec_keys if key not in state_keys]
|
||||||
|
raw_language = [
|
||||||
|
str(item[0]) if isinstance(item, (list, tuple)) else str(item)
|
||||||
|
for item in observation["language"][lang_key]
|
||||||
|
]
|
||||||
|
raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
|
||||||
|
raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
|
||||||
|
raw_flat["raw::language"] = np.array(raw_language, dtype=object)
|
||||||
|
raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
|
||||||
|
raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
|
||||||
|
|
||||||
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
||||||
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
||||||
policy.modality_configs = {
|
policy.modality_configs = {
|
||||||
@@ -136,6 +157,7 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
|||||||
embodiment_tag=np.array(tag),
|
embodiment_tag=np.array(tag),
|
||||||
meta_keys=np.array(list(meta.keys()), dtype=object),
|
meta_keys=np.array(list(meta.keys()), dtype=object),
|
||||||
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
||||||
|
**raw_flat,
|
||||||
**flat,
|
**flat,
|
||||||
)
|
)
|
||||||
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
||||||
@@ -181,7 +203,12 @@ def main():
|
|||||||
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
||||||
try:
|
try:
|
||||||
dump_one_tag(
|
dump_one_tag(
|
||||||
policy, fair_model, tag, all_modality[tag], state_spec, args,
|
policy,
|
||||||
|
fair_model,
|
||||||
|
tag,
|
||||||
|
all_modality[tag],
|
||||||
|
state_spec,
|
||||||
|
args,
|
||||||
out_dir / f"original_n1_7_{tag}.npz",
|
out_dir / f"original_n1_7_{tag}.npz",
|
||||||
)
|
)
|
||||||
done.append(tag)
|
done.append(tag)
|
||||||
|
|||||||
Reference in New Issue
Block a user