Fix GROOT relative action training stats

This commit is contained in:
Andy Wrenn
2026-06-20 06:13:56 -07:00
parent f89d10e6a7
commit 57d4cd4840
4 changed files with 303 additions and 10 deletions
@@ -268,6 +268,7 @@ class GrootConfig(PreTrainedConfig):
)
# Groot-specific model parameters
model_version: str = GROOT_N1_7
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
@@ -324,6 +325,12 @@ class GrootConfig(PreTrainedConfig):
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
use_flash_attention: bool = False
# Train on state-relative action chunks. The listed joints stay absolute, which is normally used
# for gripper channels whose command frame is not the arm joint state.
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
# Training parameters
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
@@ -358,6 +365,8 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
self.model_version = normalize_groot_model_version(self.model_version)
if self.tokenizer_assets_repo is not None:
raise ValueError(
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
@@ -408,9 +417,9 @@ class GrootConfig(PreTrainedConfig):
setattr(self, field_name, n1_7_value)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != GROOT_N1_7:
if inferred_version is not None and inferred_version != self.model_version:
message = (
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
+12 -4
View File
@@ -655,6 +655,14 @@ def make_groot_pre_post_processors(
),
DeviceProcessorStep(device=config.device),
]
relative_step: RelativeActionsProcessorStep | None = None
if config.use_relative_actions:
relative_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=list(config.relative_exclude_joints or []),
action_names=config.action_feature_names,
)
input_steps.insert(2, relative_step)
if checkpoint_assets is not None and not checkpoint_has_stats and not has_modality_stats(padded_stats):
raise ValueError(
@@ -687,10 +695,10 @@ def make_groot_pre_post_processors(
action_decode_transform=config.action_decode_transform,
)
output_steps: list[ProcessorStep] = [
action_decode_step,
DeviceProcessorStep(device="cpu"),
]
output_steps: list[ProcessorStep] = [action_decode_step]
if relative_step is not None:
output_steps.append(AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step))
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+178 -4
View File
@@ -22,6 +22,7 @@ import dataclasses
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pprint import pformat
from typing import TYPE_CHECKING, Any
@@ -49,6 +50,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -161,6 +163,169 @@ def update_policy(
return train_metrics, output_dict
def _as_int(value: Any) -> int:
if isinstance(value, torch.Tensor):
return int(value.item())
item = getattr(value, "item", None)
if callable(item):
return int(item())
return int(value)
def _to_float_tensor(value: Any, *, key: str) -> torch.Tensor:
if value is None:
raise ValueError(f"Cannot compute relative action statistics: sample is missing '{key}'.")
if isinstance(value, torch.Tensor):
return value.detach().cpu().float()
return torch.as_tensor(value, dtype=torch.float32)
def _state_reference_batch(state: torch.Tensor) -> torch.Tensor:
if state.ndim == 1:
return state.unsqueeze(0)
if state.ndim == 2:
return state
if state.ndim > 2:
return state.reshape(-1, state.shape[-1])[-1:].contiguous()
raise ValueError(f"observation.state must have at least 1 dimension, got shape {tuple(state.shape)}.")
def _action_training_batch(action: torch.Tensor, state_batch: torch.Tensor) -> torch.Tensor:
if action.ndim == 1:
return action.unsqueeze(0)
if action.ndim == 2:
# A single training sample uses (T, D) action chunks with a single (1, D) state reference.
# Batched callers may pass (B, D); keep that shape when the state batch makes it unambiguous.
if state_batch.shape[0] == action.shape[0] and state_batch.shape[0] > 1:
return action
return action.unsqueeze(0)
if action.ndim == 3:
return action
raise ValueError(f"action must be (D,), (T, D), (B, D), or (B, T, D), got {tuple(action.shape)}.")
def _unpadded_relative_action_vectors(relative_action: torch.Tensor, pad_mask: Any | None) -> torch.Tensor:
if pad_mask is None:
return relative_action.reshape(-1, relative_action.shape[-1])
keep = ~torch.as_tensor(pad_mask, dtype=torch.bool).cpu()
if relative_action.ndim == 3 and keep.ndim == 1 and relative_action.shape[0] == 1:
return relative_action[0, keep]
if relative_action.ndim == 3 and keep.ndim == 2 and tuple(keep.shape) == tuple(relative_action.shape[:2]):
return relative_action[keep]
if relative_action.ndim == 2 and keep.ndim == 1 and keep.numel() == relative_action.shape[0]:
return relative_action[keep]
return relative_action.reshape(-1, relative_action.shape[-1])
def _iter_action_state_training_samples(dataset: Any):
"""Yield action chunks, reference states, and action padding masks without decoding videos when possible."""
ensure_reader = getattr(dataset, "_ensure_reader", None)
if callable(ensure_reader):
reader = ensure_reader()
if reader.hf_dataset is None:
reader.load_and_activate()
delta_indices = getattr(reader, "delta_indices", None)
for idx in range(len(dataset)):
item = reader.hf_dataset[idx]
action = item.get(ACTION)
state = item.get(OBS_STATE)
pad_mask = None
if delta_indices is not None and ACTION in delta_indices:
ep_idx = _as_int(item["episode_index"])
abs_idx = _as_int(item["index"])
query_indices, padding = reader._get_query_indices(abs_idx, ep_idx)
action = reader._query_hf_dataset({ACTION: query_indices[ACTION]})[ACTION]
pad_mask = padding.get(f"{ACTION}_is_pad")
yield action, state, pad_mask
return
for idx in range(len(dataset)):
item = dataset[idx]
yield item.get(ACTION), item.get(OBS_STATE), item.get(f"{ACTION}_is_pad")
def _resolve_action_feature_names(active_cfg: Any, dataset: Any) -> list[str] | None:
config_names = getattr(active_cfg, "action_feature_names", None)
if config_names is not None:
return list(config_names)
features = getattr(getattr(dataset, "meta", None), "features", {}) or {}
action_feature = features.get(ACTION) if isinstance(features, dict) else None
if isinstance(action_feature, dict):
names = action_feature.get("names")
else:
names = getattr(action_feature, "names", None)
return list(names) if names is not None else None
def _make_relative_action_training_stats(
dataset: Any,
*,
exclude_joints: list[str] | None,
action_names: list[str] | None,
) -> dict[str, dict[str, Any]]:
"""Return dataset stats whose action entry describes the relative action tensor used for training."""
from lerobot.datasets.compute_stats import RunningQuantileStats
from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep, to_relative_actions
try:
dataset_len = len(dataset)
except TypeError as exc:
raise ValueError(
"Cannot compute relative action statistics for a dataset without a finite length. "
"Disable streaming or provide precomputed relative action statistics."
) from exc
if dataset_len == 0:
raise ValueError("Cannot compute relative action statistics for an empty dataset.")
stats = deepcopy(getattr(getattr(dataset, "meta", None), "stats", {}) or {})
running_stats = RunningQuantileStats()
relative_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=list(exclude_joints or []),
action_names=action_names,
)
num_vectors = 0
for action_value, state_value, pad_mask in _iter_action_state_training_samples(dataset):
action = _to_float_tensor(action_value, key=ACTION)
state = _to_float_tensor(state_value, key=OBS_STATE)
state_batch = _state_reference_batch(state)
action_batch = _action_training_batch(action, state_batch)
if action_batch.shape[0] != state_batch.shape[0]:
if state_batch.shape[0] == 1:
state_batch = state_batch.expand(action_batch.shape[0], -1)
else:
raise ValueError(
"Cannot compute relative action statistics: action and state batch sizes differ "
f"({action_batch.shape[0]} vs {state_batch.shape[0]})."
)
relative_action = to_relative_actions(
action_batch,
state_batch,
relative_step._build_mask(action_batch.shape[-1]),
)
vectors = _unpadded_relative_action_vectors(relative_action, pad_mask)
if vectors.numel() == 0:
continue
vector_count = int(vectors.reshape(-1, vectors.shape[-1]).shape[0])
running_stats.update(vectors.numpy())
num_vectors += vector_count
if num_vectors < 2:
raise ValueError(
"Cannot compute relative action statistics from fewer than 2 unpadded action vectors."
)
stats[ACTION] = running_stats.get_statistics()
return stats
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
"""
@@ -292,10 +457,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
active_cfg = cfg.trainable_config
processor_pretrained_path = active_cfg.pretrained_path
processor_stats = dataset.meta.stats
if not cfg.is_reward_model_training and getattr(active_cfg, "use_relative_actions", False):
if is_main_process:
logging.info("Computing relative-action output statistics for processor normalization")
processor_stats = _make_relative_action_training_stats(
dataset,
exclude_joints=getattr(active_cfg, "relative_exclude_joints", []),
action_names=_resolve_action_feature_names(active_cfg, dataset),
)
processor_kwargs = {}
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
processor_kwargs["dataset_stats"] = dataset.meta.stats
processor_kwargs["dataset_stats"] = processor_stats
if cfg.is_reward_model_training:
processor_kwargs["dataset_meta"] = dataset.meta
@@ -304,7 +478,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
preprocessor_overrides = {
"device_processor": {"device": device.type},
"normalizer_processor": {
"stats": dataset.meta.stats,
"stats": processor_stats,
"features": {**policy.config.input_features, **policy.config.output_features},
"norm_map": policy.config.normalization_mapping,
},
@@ -312,7 +486,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
}
postprocessor_overrides = {
"unnormalizer_processor": {
"stats": dataset.meta.stats,
"stats": processor_stats,
"features": policy.config.output_features,
"norm_map": policy.config.normalization_mapping,
},
@@ -321,7 +495,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
preprocessor_overrides["relative_actions_processor"] = {
"enabled": True,
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
"action_names": getattr(active_cfg, "action_feature_names", None),
"action_names": _resolve_action_feature_names(active_cfg, dataset),
}
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
+102
View File
@@ -23,6 +23,7 @@ from unittest.mock import patch
import numpy as np
import pytest
import torch
from safetensors.torch import load_file
from torch import nn
from lerobot.configs import FeatureType, PolicyFeature
@@ -49,6 +50,7 @@ from lerobot.processor import (
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
)
from lerobot.scripts.lerobot_train import _make_relative_action_training_stats
from lerobot.types import TransitionKey
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
@@ -1754,6 +1756,106 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat
def test_groot_n1_7_relative_action_training_processors_save_relative_action_stats(tmp_path):
input_features, output_features = _groot_features(state_dim=6, action_dim=6)
action_names = [
"shoulder_pan.pos",
"shoulder_lift.pos",
"elbow_flex.pos",
"wrist_flex.pos",
"wrist_roll.pos",
"gripper.pos",
]
config = GrootConfig(
input_features=input_features,
output_features=output_features,
device="cpu",
use_bf16=False,
action_decode_transform=None,
use_relative_actions=True,
relative_exclude_joints=["gripper"],
action_feature_names=action_names,
)
absolute_dataset_stats = {
OBS_STATE: {
"min": torch.tensor([-50.0, -60.0, -70.0, -80.0, -90.0, 0.0]),
"max": torch.tensor([50.0, 60.0, 70.0, 80.0, 90.0, 100.0]),
},
ACTION: {
"min": torch.tensor([-100.0, -110.0, -120.0, -130.0, -140.0, 0.0]),
"max": torch.tensor([100.0, 110.0, 120.0, 130.0, 140.0, 100.0]),
},
}
samples = [
{
OBS_STATE: torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 0.0]),
ACTION: torch.tensor(
[
[8.0, 17.0, 26.0, 35.0, 44.0, 0.0],
[12.0, 23.0, 34.0, 45.0, 56.0, 100.0],
]
),
},
{
OBS_STATE: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 50.0]),
ACTION: torch.tensor(
[
[-1.0, -2.0, -3.0, -4.0, -5.0, 25.0],
[1.0, 2.0, 3.0, 4.0, 5.0, 75.0],
]
),
},
]
class _RelativeStatsDataset:
meta = SimpleNamespace(
stats=absolute_dataset_stats,
features={ACTION: {"names": action_names}},
)
def __len__(self):
return len(samples)
def __getitem__(self, idx):
return samples[idx]
relative_dataset_stats = _make_relative_action_training_stats(
_RelativeStatsDataset(),
exclude_joints=["gripper"],
action_names=action_names,
)
expected_relative_action_stats = {
"min": torch.tensor([-2.0, -3.0, -4.0, -5.0, -6.0, 0.0]),
"max": torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0, 100.0]),
}
preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=relative_dataset_stats)
preprocessor.save_pretrained(tmp_path)
postprocessor.save_pretrained(tmp_path)
preprocessor_config = json.loads((tmp_path / "policy_preprocessor.json").read_text())
assert any(step.get("registry_name") == "relative_actions_processor" for step in preprocessor_config["steps"])
pack_entry = next(
step
for step in preprocessor_config["steps"]
if step.get("registry_name") == "groot_n1_7_pack_inputs_v1"
)
pack_state = load_file(tmp_path / pack_entry["state_file"])
torch.testing.assert_close(pack_state[f"{ACTION}.min"], expected_relative_action_stats["min"])
torch.testing.assert_close(pack_state[f"{ACTION}.max"], expected_relative_action_stats["max"])
postprocessor_config = json.loads((tmp_path / "policy_postprocessor.json").read_text())
assert any(step.get("registry_name") == "absolute_actions_processor" for step in postprocessor_config["steps"])
unpack_entry = next(
step
for step in postprocessor_config["steps"]
if step.get("registry_name", "").startswith("groot_action_unpack_unnormalize")
)
unpack_state = load_file(tmp_path / unpack_entry["state_file"])
torch.testing.assert_close(unpack_state[f"{ACTION}.min"], expected_relative_action_stats["min"])
torch.testing.assert_close(unpack_state[f"{ACTION}.max"], expected_relative_action_stats["max"])
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
from lerobot.policies.groot.groot_n1_7 import GR00TN17