From 57d4cd484013a85718fd3d4766af79dfffdb9683 Mon Sep 17 00:00:00 2001 From: Andy Wrenn Date: Sat, 20 Jun 2026 06:13:56 -0700 Subject: [PATCH] Fix GROOT relative action training stats --- .../policies/groot/configuration_groot.py | 13 +- src/lerobot/policies/groot/processor_groot.py | 16 +- src/lerobot/scripts/lerobot_train.py | 182 +++++++++++++++++- tests/policies/groot/test_groot_n1_7.py | 102 ++++++++++ 4 files changed, 303 insertions(+), 10 deletions(-) diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index 6ac91fdb6..004fa4be7 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -268,6 +268,7 @@ class GrootConfig(PreTrainedConfig): ) # Groot-specific model parameters + model_version: str = GROOT_N1_7 # Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and # checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the @@ -324,6 +325,12 @@ class GrootConfig(PreTrainedConfig): # Set to True only after installing a flash-attn build matching your torch/CUDA env. use_flash_attention: bool = False + # Train on state-relative action chunks. The listed joints stay absolute, which is normally used + # for gripper channels whose command frame is not the arm joint state. + use_relative_actions: bool = False + relative_exclude_joints: list[str] = field(default_factory=list) + action_feature_names: list[str] | None = None + # Training parameters optimizer_lr: float = 1e-4 optimizer_betas: tuple[float, float] = (0.95, 0.999) @@ -358,6 +365,8 @@ class GrootConfig(PreTrainedConfig): resume: bool = False def __post_init__(self): + self.model_version = normalize_groot_model_version(self.model_version) + if self.tokenizer_assets_repo is not None: raise ValueError( "Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks " @@ -408,9 +417,9 @@ class GrootConfig(PreTrainedConfig): setattr(self, field_name, n1_7_value) inferred_version = infer_groot_model_version(self.base_model_path) - if inferred_version is not None and inferred_version != GROOT_N1_7: + if inferred_version is not None and inferred_version != self.model_version: message = ( - f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path " + f"GR00T model_version '{self.model_version}' does not match base_model_path " f"'{self.base_model_path}', which looks like '{inferred_version}'." ) if inferred_version == GROOT_N1_5: diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 65ee46954..9fa7575f1 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -655,6 +655,14 @@ def make_groot_pre_post_processors( ), DeviceProcessorStep(device=config.device), ] + relative_step: RelativeActionsProcessorStep | None = None + if config.use_relative_actions: + relative_step = RelativeActionsProcessorStep( + enabled=True, + exclude_joints=list(config.relative_exclude_joints or []), + action_names=config.action_feature_names, + ) + input_steps.insert(2, relative_step) if checkpoint_assets is not None and not checkpoint_has_stats and not has_modality_stats(padded_stats): raise ValueError( @@ -687,10 +695,10 @@ def make_groot_pre_post_processors( action_decode_transform=config.action_decode_transform, ) - output_steps: list[ProcessorStep] = [ - action_decode_step, - DeviceProcessorStep(device="cpu"), - ] + output_steps: list[ProcessorStep] = [action_decode_step] + if relative_step is not None: + output_steps.append(AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)) + output_steps.append(DeviceProcessorStep(device="cpu")) return ( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 4ddef3105..d99564fd4 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -22,6 +22,7 @@ import dataclasses import logging import time from contextlib import nullcontext +from copy import deepcopy from pprint import pformat from typing import TYPE_CHECKING, Any @@ -49,6 +50,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.rewards import make_reward_pre_post_processors from lerobot.utils.collate import lerobot_collate_fn +from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed @@ -161,6 +163,169 @@ def update_policy( return train_metrics, output_dict +def _as_int(value: Any) -> int: + if isinstance(value, torch.Tensor): + return int(value.item()) + item = getattr(value, "item", None) + if callable(item): + return int(item()) + return int(value) + + +def _to_float_tensor(value: Any, *, key: str) -> torch.Tensor: + if value is None: + raise ValueError(f"Cannot compute relative action statistics: sample is missing '{key}'.") + if isinstance(value, torch.Tensor): + return value.detach().cpu().float() + return torch.as_tensor(value, dtype=torch.float32) + + +def _state_reference_batch(state: torch.Tensor) -> torch.Tensor: + if state.ndim == 1: + return state.unsqueeze(0) + if state.ndim == 2: + return state + if state.ndim > 2: + return state.reshape(-1, state.shape[-1])[-1:].contiguous() + raise ValueError(f"observation.state must have at least 1 dimension, got shape {tuple(state.shape)}.") + + +def _action_training_batch(action: torch.Tensor, state_batch: torch.Tensor) -> torch.Tensor: + if action.ndim == 1: + return action.unsqueeze(0) + if action.ndim == 2: + # A single training sample uses (T, D) action chunks with a single (1, D) state reference. + # Batched callers may pass (B, D); keep that shape when the state batch makes it unambiguous. + if state_batch.shape[0] == action.shape[0] and state_batch.shape[0] > 1: + return action + return action.unsqueeze(0) + if action.ndim == 3: + return action + raise ValueError(f"action must be (D,), (T, D), (B, D), or (B, T, D), got {tuple(action.shape)}.") + + +def _unpadded_relative_action_vectors(relative_action: torch.Tensor, pad_mask: Any | None) -> torch.Tensor: + if pad_mask is None: + return relative_action.reshape(-1, relative_action.shape[-1]) + + keep = ~torch.as_tensor(pad_mask, dtype=torch.bool).cpu() + if relative_action.ndim == 3 and keep.ndim == 1 and relative_action.shape[0] == 1: + return relative_action[0, keep] + if relative_action.ndim == 3 and keep.ndim == 2 and tuple(keep.shape) == tuple(relative_action.shape[:2]): + return relative_action[keep] + if relative_action.ndim == 2 and keep.ndim == 1 and keep.numel() == relative_action.shape[0]: + return relative_action[keep] + return relative_action.reshape(-1, relative_action.shape[-1]) + + +def _iter_action_state_training_samples(dataset: Any): + """Yield action chunks, reference states, and action padding masks without decoding videos when possible.""" + + ensure_reader = getattr(dataset, "_ensure_reader", None) + if callable(ensure_reader): + reader = ensure_reader() + if reader.hf_dataset is None: + reader.load_and_activate() + delta_indices = getattr(reader, "delta_indices", None) + for idx in range(len(dataset)): + item = reader.hf_dataset[idx] + action = item.get(ACTION) + state = item.get(OBS_STATE) + pad_mask = None + if delta_indices is not None and ACTION in delta_indices: + ep_idx = _as_int(item["episode_index"]) + abs_idx = _as_int(item["index"]) + query_indices, padding = reader._get_query_indices(abs_idx, ep_idx) + action = reader._query_hf_dataset({ACTION: query_indices[ACTION]})[ACTION] + pad_mask = padding.get(f"{ACTION}_is_pad") + yield action, state, pad_mask + return + + for idx in range(len(dataset)): + item = dataset[idx] + yield item.get(ACTION), item.get(OBS_STATE), item.get(f"{ACTION}_is_pad") + + +def _resolve_action_feature_names(active_cfg: Any, dataset: Any) -> list[str] | None: + config_names = getattr(active_cfg, "action_feature_names", None) + if config_names is not None: + return list(config_names) + + features = getattr(getattr(dataset, "meta", None), "features", {}) or {} + action_feature = features.get(ACTION) if isinstance(features, dict) else None + if isinstance(action_feature, dict): + names = action_feature.get("names") + else: + names = getattr(action_feature, "names", None) + return list(names) if names is not None else None + + +def _make_relative_action_training_stats( + dataset: Any, + *, + exclude_joints: list[str] | None, + action_names: list[str] | None, +) -> dict[str, dict[str, Any]]: + """Return dataset stats whose action entry describes the relative action tensor used for training.""" + + from lerobot.datasets.compute_stats import RunningQuantileStats + from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep, to_relative_actions + + try: + dataset_len = len(dataset) + except TypeError as exc: + raise ValueError( + "Cannot compute relative action statistics for a dataset without a finite length. " + "Disable streaming or provide precomputed relative action statistics." + ) from exc + + if dataset_len == 0: + raise ValueError("Cannot compute relative action statistics for an empty dataset.") + + stats = deepcopy(getattr(getattr(dataset, "meta", None), "stats", {}) or {}) + running_stats = RunningQuantileStats() + relative_step = RelativeActionsProcessorStep( + enabled=True, + exclude_joints=list(exclude_joints or []), + action_names=action_names, + ) + num_vectors = 0 + + for action_value, state_value, pad_mask in _iter_action_state_training_samples(dataset): + action = _to_float_tensor(action_value, key=ACTION) + state = _to_float_tensor(state_value, key=OBS_STATE) + state_batch = _state_reference_batch(state) + action_batch = _action_training_batch(action, state_batch) + if action_batch.shape[0] != state_batch.shape[0]: + if state_batch.shape[0] == 1: + state_batch = state_batch.expand(action_batch.shape[0], -1) + else: + raise ValueError( + "Cannot compute relative action statistics: action and state batch sizes differ " + f"({action_batch.shape[0]} vs {state_batch.shape[0]})." + ) + + relative_action = to_relative_actions( + action_batch, + state_batch, + relative_step._build_mask(action_batch.shape[-1]), + ) + vectors = _unpadded_relative_action_vectors(relative_action, pad_mask) + if vectors.numel() == 0: + continue + vector_count = int(vectors.reshape(-1, vectors.shape[-1]).shape[0]) + running_stats.update(vectors.numpy()) + num_vectors += vector_count + + if num_vectors < 2: + raise ValueError( + "Cannot compute relative action statistics from fewer than 2 unpadded action vectors." + ) + + stats[ACTION] = running_stats.get_statistics() + return stats + + @parser.wrap() def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): """ @@ -292,10 +457,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): active_cfg = cfg.trainable_config processor_pretrained_path = active_cfg.pretrained_path + processor_stats = dataset.meta.stats + if not cfg.is_reward_model_training and getattr(active_cfg, "use_relative_actions", False): + if is_main_process: + logging.info("Computing relative-action output statistics for processor normalization") + processor_stats = _make_relative_action_training_stats( + dataset, + exclude_joints=getattr(active_cfg, "relative_exclude_joints", []), + action_names=_resolve_action_feature_names(active_cfg, dataset), + ) processor_kwargs = {} if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path: - processor_kwargs["dataset_stats"] = dataset.meta.stats + processor_kwargs["dataset_stats"] = processor_stats if cfg.is_reward_model_training: processor_kwargs["dataset_meta"] = dataset.meta @@ -304,7 +478,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): preprocessor_overrides = { "device_processor": {"device": device.type}, "normalizer_processor": { - "stats": dataset.meta.stats, + "stats": processor_stats, "features": {**policy.config.input_features, **policy.config.output_features}, "norm_map": policy.config.normalization_mapping, }, @@ -312,7 +486,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): } postprocessor_overrides = { "unnormalizer_processor": { - "stats": dataset.meta.stats, + "stats": processor_stats, "features": policy.config.output_features, "norm_map": policy.config.normalization_mapping, }, @@ -321,7 +495,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): preprocessor_overrides["relative_actions_processor"] = { "enabled": True, "exclude_joints": getattr(active_cfg, "relative_exclude_joints", []), - "action_names": getattr(active_cfg, "action_feature_names", None), + "action_names": _resolve_action_feature_names(active_cfg, dataset), } postprocessor_overrides["absolute_actions_processor"] = {"enabled": True} processor_kwargs["preprocessor_overrides"] = preprocessor_overrides diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index dd6e8eb30..5eba6f075 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -23,6 +23,7 @@ from unittest.mock import patch import numpy as np import pytest import torch +from safetensors.torch import load_file from torch import nn from lerobot.configs import FeatureType, PolicyFeature @@ -49,6 +50,7 @@ from lerobot.processor import ( PolicyProcessorPipeline, RelativeActionsProcessorStep, ) +from lerobot.scripts.lerobot_train import _make_relative_action_training_stats from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE @@ -1754,6 +1756,106 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat +def test_groot_n1_7_relative_action_training_processors_save_relative_action_stats(tmp_path): + input_features, output_features = _groot_features(state_dim=6, action_dim=6) + action_names = [ + "shoulder_pan.pos", + "shoulder_lift.pos", + "elbow_flex.pos", + "wrist_flex.pos", + "wrist_roll.pos", + "gripper.pos", + ] + config = GrootConfig( + input_features=input_features, + output_features=output_features, + device="cpu", + use_bf16=False, + action_decode_transform=None, + use_relative_actions=True, + relative_exclude_joints=["gripper"], + action_feature_names=action_names, + ) + absolute_dataset_stats = { + OBS_STATE: { + "min": torch.tensor([-50.0, -60.0, -70.0, -80.0, -90.0, 0.0]), + "max": torch.tensor([50.0, 60.0, 70.0, 80.0, 90.0, 100.0]), + }, + ACTION: { + "min": torch.tensor([-100.0, -110.0, -120.0, -130.0, -140.0, 0.0]), + "max": torch.tensor([100.0, 110.0, 120.0, 130.0, 140.0, 100.0]), + }, + } + samples = [ + { + OBS_STATE: torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 0.0]), + ACTION: torch.tensor( + [ + [8.0, 17.0, 26.0, 35.0, 44.0, 0.0], + [12.0, 23.0, 34.0, 45.0, 56.0, 100.0], + ] + ), + }, + { + OBS_STATE: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 50.0]), + ACTION: torch.tensor( + [ + [-1.0, -2.0, -3.0, -4.0, -5.0, 25.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 75.0], + ] + ), + }, + ] + + class _RelativeStatsDataset: + meta = SimpleNamespace( + stats=absolute_dataset_stats, + features={ACTION: {"names": action_names}}, + ) + + def __len__(self): + return len(samples) + + def __getitem__(self, idx): + return samples[idx] + + relative_dataset_stats = _make_relative_action_training_stats( + _RelativeStatsDataset(), + exclude_joints=["gripper"], + action_names=action_names, + ) + expected_relative_action_stats = { + "min": torch.tensor([-2.0, -3.0, -4.0, -5.0, -6.0, 0.0]), + "max": torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0, 100.0]), + } + + preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=relative_dataset_stats) + preprocessor.save_pretrained(tmp_path) + postprocessor.save_pretrained(tmp_path) + + preprocessor_config = json.loads((tmp_path / "policy_preprocessor.json").read_text()) + assert any(step.get("registry_name") == "relative_actions_processor" for step in preprocessor_config["steps"]) + pack_entry = next( + step + for step in preprocessor_config["steps"] + if step.get("registry_name") == "groot_n1_7_pack_inputs_v1" + ) + pack_state = load_file(tmp_path / pack_entry["state_file"]) + torch.testing.assert_close(pack_state[f"{ACTION}.min"], expected_relative_action_stats["min"]) + torch.testing.assert_close(pack_state[f"{ACTION}.max"], expected_relative_action_stats["max"]) + + postprocessor_config = json.loads((tmp_path / "policy_postprocessor.json").read_text()) + assert any(step.get("registry_name") == "absolute_actions_processor" for step in postprocessor_config["steps"]) + unpack_entry = next( + step + for step in postprocessor_config["steps"] + if step.get("registry_name", "").startswith("groot_action_unpack_unnormalize") + ) + unpack_state = load_file(tmp_path / unpack_entry["state_file"]) + torch.testing.assert_close(unpack_state[f"{ACTION}.min"], expected_relative_action_stats["min"]) + torch.testing.assert_close(unpack_state[f"{ACTION}.max"], expected_relative_action_stats["max"]) + + def test_groot_policy_selects_n1_7_model_class(monkeypatch): from lerobot.policies.groot.groot_n1_7 import GR00TN17