mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 49cb1ee7db | |||
| b23b6edcd9 | |||
| d7b09f77c5 | |||
| 34e70f43b8 | |||
| a35e6a4b46 |
@@ -218,6 +218,7 @@ groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[peft-dep]",
|
||||
"lerobot[diffusers-dep]",
|
||||
"lerobot[dataset]", # NOTE: processor_groot builds a LeRobotDataset for relative-action training stats
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
|
||||
@@ -324,9 +324,14 @@ class GrootConfig(PreTrainedConfig):
|
||||
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
|
||||
use_flash_attention: bool = False
|
||||
|
||||
# Enable GR00T-style state-relative action chunks. Prefer deriving action representation from
|
||||
# embodiment metadata; relative_exclude_joints is a flat-vector override for datasets without it.
|
||||
# Enable GR00T-style state-relative action chunks (action chunk expressed relative to the current
|
||||
# observation state).
|
||||
use_relative_actions: bool = False
|
||||
|
||||
# relative_exclude_joints names the action dimensions that stay absolute; the
|
||||
# match is substring/case-insensitive against the dataset action feature names. With the empty
|
||||
# default every dimension is treated as relative, including the gripper -- set e.g. ["gripper"] to
|
||||
# keep the gripper absolute, matching the Isaac-GR00T single-arm + absolute-gripper convention.
|
||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
||||
|
||||
# Training parameters
|
||||
|
||||
@@ -20,13 +20,14 @@ from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.v2.functional as tv_functional
|
||||
from einops import rearrange
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
from lerobot.utils.import_utils import _datasets_available, _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import (
|
||||
@@ -43,6 +44,11 @@ else:
|
||||
Qwen3VLProcessor = None
|
||||
Qwen3VLVideoProcessor = None
|
||||
|
||||
if TYPE_CHECKING or _datasets_available:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
else:
|
||||
LeRobotDataset = None
|
||||
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
@@ -810,7 +816,7 @@ def _make_relative_action_training_stats_from_dataset_meta(
|
||||
if dataset_meta is None or repo_id is None or root is None or fps is None:
|
||||
return None
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
require_package("datasets", extra="groot")
|
||||
|
||||
delta_timestamps = {ACTION: [index / fps for index in config.action_delta_indices]}
|
||||
dataset = LeRobotDataset(
|
||||
@@ -996,6 +1002,7 @@ def _build_n1_7_relative_action_processor_assets(
|
||||
}
|
||||
for group in groups
|
||||
]
|
||||
# 40 matches the action horizon of the only N1.7 base model (nvidia/GR00T-N1.7-3B)
|
||||
action_horizon = min(config.chunk_size, 40)
|
||||
modality_config: dict[str, Any] = {
|
||||
"state": {"modality_keys": [group.key for group in groups]},
|
||||
@@ -1194,6 +1201,13 @@ def make_groot_pre_post_processors(
|
||||
)
|
||||
relative_step: RelativeActionsProcessorStep | None = None
|
||||
if config.use_relative_actions and not uses_native_relative_actions:
|
||||
logging.warning(
|
||||
"GR00T relative actions are using the generic RelativeActionsProcessorStep fallback because "
|
||||
"the checkpoint already carries non-relative statistics. Relative deltas will be normalized "
|
||||
"with absolute action stats rather than Isaac-GR00T's per-horizon relative stats. For "
|
||||
"OSS-faithful relative normalization, build from a checkpoint without baked-in stats (or "
|
||||
"pass dataset_meta) so native relative stats are computed."
|
||||
)
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=list(config.relative_exclude_joints or []),
|
||||
@@ -1317,13 +1331,6 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
|
||||
target_h, target_w = image_target_size
|
||||
|
||||
try:
|
||||
import cv2
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
|
||||
) from exc
|
||||
|
||||
image_np = np.asarray(image)
|
||||
if image_np.ndim == 2:
|
||||
image_np = np.repeat(image_np[:, :, None], 3, axis=2)
|
||||
@@ -1658,6 +1665,25 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
return None
|
||||
return torch.cat(normalized_groups, dim=-1)
|
||||
|
||||
def _uses_relative_action_groups(self) -> bool:
|
||||
"""True when the action modality declares at least one relative group.
|
||||
|
||||
Relative groups normalize with per-chunk-timestep (2D) ``relative_action`` stats, which the
|
||||
flat ``_min_max_norm`` fallback cannot honor, so a relative config that fails grouped
|
||||
normalization must fail loudly rather than silently wrongly scale every timestep.
|
||||
"""
|
||||
if not isinstance(self.modality_config, dict):
|
||||
return False
|
||||
action_config = self.modality_config.get("action", {})
|
||||
if not isinstance(action_config, dict):
|
||||
return False
|
||||
action_configs = action_config.get("action_configs", [])
|
||||
if not isinstance(action_configs, list):
|
||||
return False
|
||||
return any(
|
||||
isinstance(cfg, dict) and config_value(cfg.get("rep")) == "relative" for cfg in action_configs
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
@@ -1775,6 +1801,15 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
normalized_action = self._normalize_action_groups_for_training(action)
|
||||
if normalized_action is not None:
|
||||
action = normalized_action
|
||||
elif self._uses_relative_action_groups():
|
||||
raise ValueError(
|
||||
"GrootN17PackInputsStep could not apply native grouped normalization to a "
|
||||
"relative-action chunk: the action layout or horizon does not match the "
|
||||
f"checkpoint relative_action stats (action shape {tuple(action.shape)}). The flat "
|
||||
"min/max fallback cannot honor per-chunk-timestep relative stats, so refusing to "
|
||||
"silently wrongly normalize. Recompute the relative action stats so their horizon and "
|
||||
"dimensions match the action chunk."
|
||||
)
|
||||
else:
|
||||
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
|
||||
action = flat.view(bsz, horizon, dim)
|
||||
|
||||
@@ -129,6 +129,7 @@ _placo_available = is_package_available("placo")
|
||||
_hidapi_available = is_package_available("hidapi", import_name="hid")
|
||||
|
||||
# Data / serialization
|
||||
_datasets_available = is_package_available("datasets")
|
||||
_pandas_available = is_package_available("pandas")
|
||||
_faker_available = is_package_available("faker")
|
||||
|
||||
|
||||
@@ -30,10 +30,12 @@ from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.policies.factory import make_policy_config, make_pre_post_processors
|
||||
from lerobot.policies.groot.configuration_groot import (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
GROOT_N1_7,
|
||||
GROOT_N1_7_BASE_MODEL,
|
||||
GrootConfig,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
from lerobot.policies.groot.processor_groot import (
|
||||
@@ -92,6 +94,8 @@ def _raw_n1_7_libero_config(model_path) -> GrootConfig:
|
||||
|
||||
|
||||
def test_n1_7_backbone_accepts_transformers_5_layout_and_forwards_mm_token_type_ids(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
import lerobot.policies.groot.groot_n1_7 as groot_n1_7
|
||||
@@ -185,6 +189,8 @@ def test_n1_7_backbone_accepts_transformers_5_layout_and_forwards_mm_token_type_
|
||||
|
||||
|
||||
def test_n1_7_backbone_preserves_missing_qwen_optional_dependency_error(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
import lerobot.policies.groot.groot_n1_7 as groot_n1_7
|
||||
|
||||
monkeypatch.setattr(
|
||||
@@ -350,6 +356,18 @@ def test_groot_defaults_use_n1_7():
|
||||
assert len(config.action_delta_indices) == 40
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"])
|
||||
def test_groot_normalize_model_version_rejects_n1_5_aliases(legacy_version):
|
||||
# model_version is no longer a GrootConfig field, but normalize_groot_model_version is still
|
||||
# live (e.g. via infer_groot_model_version) and must keep rejecting N1.5 with removal guidance.
|
||||
with pytest.raises(ValueError, match="Unsupported GR00T model_version"):
|
||||
normalize_groot_model_version(legacy_version)
|
||||
|
||||
|
||||
def test_groot_normalize_model_version_accepts_n1_7():
|
||||
assert normalize_groot_model_version(GROOT_N1_7) == GROOT_N1_7
|
||||
|
||||
|
||||
def test_groot_n1_7_accepts_named_action_decode_transform():
|
||||
config = GrootConfig(
|
||||
action_decode_transform="libero",
|
||||
@@ -393,6 +411,8 @@ def test_groot_predict_action_chunk_accepts_rtc_kwargs():
|
||||
|
||||
|
||||
def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
@@ -422,6 +442,8 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch):
|
||||
|
||||
|
||||
def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
@@ -455,6 +477,8 @@ def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch):
|
||||
|
||||
|
||||
def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
@@ -508,6 +532,8 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
|
||||
|
||||
|
||||
def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "GR00T-N1.7-local"
|
||||
@@ -522,6 +548,8 @@ def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatc
|
||||
|
||||
|
||||
def test_groot_from_pretrained_infers_n1_7_from_ambiguous_local_config(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "local-checkpoint"
|
||||
@@ -997,6 +1025,42 @@ def test_groot_n1_7_pack_inputs_normalizes_action_chunk_per_dimension_before_pad
|
||||
assert action_mask[0, :, 3:].sum().item() == 0
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_raises_when_relative_groups_cannot_normalize():
|
||||
# Relative groups carry per-chunk-timestep stats; if the action horizon exceeds the available
|
||||
# stat rows, grouped normalization cannot apply and the flat fallback would silently wrongly scale.
|
||||
step = GrootN17PackInputsStep(
|
||||
action_horizon=3,
|
||||
valid_action_horizon=3,
|
||||
max_state_dim=2,
|
||||
max_action_dim=2,
|
||||
normalize_min_max=True,
|
||||
raw_stats={
|
||||
"state": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}},
|
||||
"action": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}},
|
||||
# only one horizon row, but the action chunk has horizon 3
|
||||
"relative_action": {"single_arm": {"min": [[-1.0, -1.0]], "max": [[1.0, 1.0]]}},
|
||||
},
|
||||
modality_config={
|
||||
"state": {"modality_keys": ["single_arm"]},
|
||||
"action": {
|
||||
"modality_keys": ["single_arm"],
|
||||
"action_configs": [
|
||||
{"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}
|
||||
],
|
||||
"delta_indices": [0, 1, 2],
|
||||
},
|
||||
},
|
||||
)
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 2)},
|
||||
TransitionKey.ACTION: torch.zeros(1, 3, 2),
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="could not apply native grouped normalization"):
|
||||
step(transition)
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_trains_native_relative_groups_with_absolute_gripper():
|
||||
step = GrootN17PackInputsStep(
|
||||
action_horizon=2,
|
||||
@@ -2192,7 +2256,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
|
||||
assert kwargs["delta_timestamps"][ACTION] == [0.0, 1 / runtime_meta.fps]
|
||||
return _RelativeStatsDataset()
|
||||
|
||||
monkeypatch.setattr("lerobot.datasets.lerobot_dataset.LeRobotDataset", _fake_lerobot_dataset)
|
||||
monkeypatch.setattr("lerobot.policies.groot.processor_groot.LeRobotDataset", _fake_lerobot_dataset)
|
||||
config._runtime_dataset_meta = runtime_meta
|
||||
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=absolute_dataset_stats)
|
||||
@@ -2430,6 +2494,8 @@ def test_groot_n1_7_relative_action_stats_skip_padded_tail_chunks():
|
||||
|
||||
|
||||
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
called = {}
|
||||
@@ -2447,6 +2513,8 @@ def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
|
||||
|
||||
def test_groot_policy_forwards_n1_7_qwen_inputs(monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
dummy_model = _DummyGrootModel()
|
||||
@@ -2505,6 +2573,8 @@ def test_groot_select_action_rejects_relative_action_policies():
|
||||
|
||||
|
||||
def test_groot_n1_7_select_action_uses_checkpoint_valid_horizon(tmp_path, monkeypatch):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
@@ -2697,6 +2767,8 @@ def test_qwen3_backbone_can_initialize_from_config_without_downloading_weights(m
|
||||
|
||||
|
||||
def test_gr00t_n1_7_from_pretrained_defers_backbone_weight_loading(monkeypatch, tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from huggingface_hub.errors import HFValidationError
|
||||
|
||||
import lerobot.policies.groot.groot_n1_7 as groot_n1_7
|
||||
|
||||
@@ -21,7 +21,6 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
from lerobot.policies.groot.action_head.cross_attention_dit import AlternateVLDiT
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
@@ -110,6 +109,8 @@ def test_groot_n1_7_vlm_chat_content_order_matches_oss_reference():
|
||||
def test_groot_n1_7_alternate_vl_dit_matches_oss_reference():
|
||||
"""Run the LeRobot DiT with native OSS weights and identical inputs."""
|
||||
|
||||
pytest.importorskip("diffusers")
|
||||
|
||||
fixture = torch.load(_fixture_path("alternate_vl_dit_small.pt"), map_location="cpu", weights_only=True)
|
||||
model = AlternateVLDiT(
|
||||
output_dim=8,
|
||||
@@ -228,6 +229,10 @@ def test_groot_n1_7_qwen_backbone_matches_oss_checkpoint_reference():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("The 3B OSS Qwen parity test requires CUDA.")
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
|
||||
fixture = torch.load(_fixture_path("qwen_backbone_so101.pt"), map_location="cpu", weights_only=True)
|
||||
model = GR00TN17.from_pretrained(checkpoint).to(device="cuda", dtype=torch.bfloat16).eval()
|
||||
backbone_input = BatchFeature(
|
||||
|
||||
@@ -62,10 +62,7 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
|
||||
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
||||
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
||||
# present-but-empty so the processor's state transform finds every expected key.
|
||||
state = {
|
||||
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
|
||||
for k, dim in state_spec
|
||||
}
|
||||
state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
|
||||
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
||||
return {"video": video, "state": state, "language": language}
|
||||
|
||||
@@ -181,7 +178,12 @@ def main():
|
||||
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
||||
try:
|
||||
dump_one_tag(
|
||||
policy, fair_model, tag, all_modality[tag], state_spec, args,
|
||||
policy,
|
||||
fair_model,
|
||||
tag,
|
||||
all_modality[tag],
|
||||
state_spec,
|
||||
args,
|
||||
out_dir / f"original_n1_7_{tag}.npz",
|
||||
)
|
||||
done.append(tag)
|
||||
|
||||
@@ -2957,11 +2957,17 @@ gamepad = [
|
||||
{ name = "pygame" },
|
||||
]
|
||||
groot = [
|
||||
{ name = "av" },
|
||||
{ name = "datasets" },
|
||||
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
|
||||
{ name = "diffusers" },
|
||||
{ name = "dm-tree" },
|
||||
{ name = "jsonlines" },
|
||||
{ name = "pandas" },
|
||||
{ name = "peft" },
|
||||
{ name = "pyarrow" },
|
||||
{ name = "timm" },
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
grpcio-dep = [
|
||||
@@ -3240,6 +3246,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'annotations'" },
|
||||
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'core-scripts'" },
|
||||
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'dataset-viz'" },
|
||||
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'hilserl'" },
|
||||
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'libero'" },
|
||||
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'metaworld'" },
|
||||
|
||||
Reference in New Issue
Block a user