mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-05 00:57:06 +00:00
Fix GROOT relative action training stats
This commit is contained in:
@@ -268,6 +268,7 @@ class GrootConfig(PreTrainedConfig):
|
||||
)
|
||||
|
||||
# Groot-specific model parameters
|
||||
model_version: str = GROOT_N1_7
|
||||
|
||||
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
|
||||
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
|
||||
@@ -324,6 +325,12 @@ class GrootConfig(PreTrainedConfig):
|
||||
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
|
||||
use_flash_attention: bool = False
|
||||
|
||||
# Train on state-relative action chunks. The listed joints stay absolute, which is normally used
|
||||
# for gripper channels whose command frame is not the arm joint state.
|
||||
use_relative_actions: bool = False
|
||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Training parameters
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
@@ -358,6 +365,8 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
@@ -408,9 +417,9 @@ class GrootConfig(PreTrainedConfig):
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != GROOT_N1_7:
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
|
||||
@@ -655,6 +655,14 @@ def make_groot_pre_post_processors(
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
relative_step: RelativeActionsProcessorStep | None = None
|
||||
if config.use_relative_actions:
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=list(config.relative_exclude_joints or []),
|
||||
action_names=config.action_feature_names,
|
||||
)
|
||||
input_steps.insert(2, relative_step)
|
||||
|
||||
if checkpoint_assets is not None and not checkpoint_has_stats and not has_modality_stats(padded_stats):
|
||||
raise ValueError(
|
||||
@@ -687,10 +695,10 @@ def make_groot_pre_post_processors(
|
||||
action_decode_transform=config.action_decode_transform,
|
||||
)
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
action_decode_step,
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
output_steps: list[ProcessorStep] = [action_decode_step]
|
||||
if relative_step is not None:
|
||||
output_steps.append(AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step))
|
||||
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
|
||||
@@ -22,6 +22,7 @@ import dataclasses
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pprint import pformat
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -49,6 +50,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -161,6 +163,169 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def _as_int(value: Any) -> int:
    """Coerce a scalar-like value to a built-in ``int``.

    Tensors, and any other object exposing a callable ``item`` attribute, are
    unwrapped through ``item()`` before conversion. Everything else is handed
    to ``int`` directly.
    """
    if isinstance(value, torch.Tensor):
        return int(value.item())
    unwrap = getattr(value, "item", None)
    # Objects without a usable ``item()`` fall back to plain ``int`` conversion.
    return int(unwrap()) if callable(unwrap) else int(value)
|
||||
|
||||
|
||||
def _to_float_tensor(value: Any, *, key: str) -> torch.Tensor:
    """Return ``value`` as a float32 tensor.

    Existing tensors are detached, moved to CPU and cast with ``.float()``;
    any other non-``None`` value goes through ``torch.as_tensor``.

    Args:
        value: The raw sample entry to convert.
        key: Name of the sample entry, used only in the error message.

    Raises:
        ValueError: If ``value`` is ``None``.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().float()
    if value is not None:
        return torch.as_tensor(value, dtype=torch.float32)
    raise ValueError(f"Cannot compute relative action statistics: sample is missing '{key}'.")
|
||||
|
||||
|
||||
def _state_reference_batch(state: torch.Tensor) -> torch.Tensor:
    """Shape a state tensor into a 2-D ``(rows, D)`` reference batch.

    * rank 1: a leading dimension of size 1 is added.
    * rank 2: returned unchanged.
    * rank > 2: flattened to ``(-1, D)`` and only the final row is kept,
      giving a ``(1, D)`` tensor.

    Raises:
        ValueError: If ``state`` is a 0-dimensional tensor.
    """
    rank = state.ndim
    if rank == 0:
        raise ValueError(f"observation.state must have at least 1 dimension, got shape {tuple(state.shape)}.")
    if rank == 1:
        return state.unsqueeze(0)
    if rank == 2:
        return state
    # NOTE(review): for rank > 2 this keeps a single row (the very last one),
    # not one row per leading batch entry -- confirm this is intended for
    # batched inputs.
    return state.reshape(-1, state.shape[-1])[-1:].contiguous()
|
||||
|
||||
|
||||
def _action_training_batch(action: torch.Tensor, state_batch: torch.Tensor) -> torch.Tensor:
    """Give an action tensor a leading batch dimension where it lacks one.

    Args:
        action: Tensor of rank 1, 2 or 3.
        state_batch: 2-D state reference; its row count is used to tell a
            ``(B, D)`` action batch apart from a single ``(T, D)`` chunk.

    Returns:
        Rank-3 input unchanged; rank-1 input with a leading dimension added;
        rank-2 input either unchanged (when it is read as ``(B, D)``) or with
        a leading dimension added.

    Raises:
        ValueError: If ``action`` has any other rank.
    """
    rank = action.ndim
    if rank == 3:
        return action
    if rank == 1:
        return action.unsqueeze(0)
    if rank != 2:
        raise ValueError(f"action must be (D,), (T, D), (B, D), or (B, T, D), got {tuple(action.shape)}.")

    # A lone training sample is a (T, D) chunk paired with a (1, D) state.
    # Only read the tensor as an already batched (B, D) input when the state
    # batch has several rows and its row count matches the action's.
    state_rows = state_batch.shape[0]
    already_batched = state_rows == action.shape[0] and state_rows > 1
    return action if already_batched else action.unsqueeze(0)
|
||||
|
||||
|
||||
def _unpadded_relative_action_vectors(relative_action: torch.Tensor, pad_mask: Any | None) -> torch.Tensor:
|
||||
if pad_mask is None:
|
||||
return relative_action.reshape(-1, relative_action.shape[-1])
|
||||
|
||||
keep = ~torch.as_tensor(pad_mask, dtype=torch.bool).cpu()
|
||||
if relative_action.ndim == 3 and keep.ndim == 1 and relative_action.shape[0] == 1:
|
||||
return relative_action[0, keep]
|
||||
if relative_action.ndim == 3 and keep.ndim == 2 and tuple(keep.shape) == tuple(relative_action.shape[:2]):
|
||||
return relative_action[keep]
|
||||
if relative_action.ndim == 2 and keep.ndim == 1 and keep.numel() == relative_action.shape[0]:
|
||||
return relative_action[keep]
|
||||
return relative_action.reshape(-1, relative_action.shape[-1])
|
||||
|
||||
|
||||
def _iter_action_state_training_samples(dataset: Any):
    """Yield ``(action, state, action_pad_mask)`` for every index of ``dataset``.

    When the dataset exposes a callable ``_ensure_reader``, rows are read from
    ``reader.hf_dataset`` directly instead of going through
    ``dataset[idx]`` (the original note says this avoids decoding videos when
    possible). If the reader's ``delta_indices`` contain the action key, the
    action chunk and its padding mask are rebuilt through the reader's query
    helpers. Otherwise the mask is ``None``.

    Datasets without ``_ensure_reader`` are indexed normally and the three
    values are taken from each returned item with ``.get``.
    """
    ensure_reader = getattr(dataset, "_ensure_reader", None)

    if not callable(ensure_reader):
        # Generic path: rely on the dataset's own ``__getitem__``.
        for position in range(len(dataset)):
            sample = dataset[position]
            yield sample.get(ACTION), sample.get(OBS_STATE), sample.get(f"{ACTION}_is_pad")
        return

    reader = ensure_reader()
    if reader.hf_dataset is None:
        reader.load_and_activate()
    delta_indices = getattr(reader, "delta_indices", None)

    for position in range(len(dataset)):
        row = reader.hf_dataset[position]
        action = row.get(ACTION)
        state = row.get(OBS_STATE)
        pad_mask = None
        if delta_indices is not None and ACTION in delta_indices:
            # Replace the single-row action with the chunk selected by the
            # reader's query indices, and pick up the matching padding mask.
            episode = _as_int(row["episode_index"])
            absolute = _as_int(row["index"])
            query_indices, padding = reader._get_query_indices(absolute, episode)
            action = reader._query_hf_dataset({ACTION: query_indices[ACTION]})[ACTION]
            pad_mask = padding.get(f"{ACTION}_is_pad")
        yield action, state, pad_mask
|
||||
|
||||
|
||||
def _resolve_action_feature_names(active_cfg: Any, dataset: Any) -> list[str] | None:
    """Pick the per-dimension action names, preferring the config over the dataset.

    ``active_cfg.action_feature_names`` wins whenever it is not ``None``.
    Otherwise the ``names`` entry of the action feature found in
    ``dataset.meta.features`` is used. The feature may be a dict or an object
    with a ``names`` attribute. Returns ``None`` when neither source provides
    names.
    """
    configured = getattr(active_cfg, "action_feature_names", None)
    if configured is not None:
        return list(configured)

    meta = getattr(dataset, "meta", None)
    features = getattr(meta, "features", {}) or {}
    action_spec = features.get(ACTION) if isinstance(features, dict) else None

    # NOTE(review): ``list(names)`` assumes ``names`` is a flat sequence; a
    # mapping would yield its keys -- confirm against the dataset schema.
    names = action_spec.get("names") if isinstance(action_spec, dict) else getattr(action_spec, "names", None)
    return None if names is None else list(names)
|
||||
|
||||
|
||||
def _make_relative_action_training_stats(
    dataset: Any,
    *,
    exclude_joints: list[str] | None,
    action_names: list[str] | None,
) -> dict[str, dict[str, Any]]:
    """Build dataset stats whose action entry describes relative actions.

    A deep copy of ``dataset.meta.stats`` is taken (an empty dict when
    missing). Every sample is then converted to a relative action with
    ``to_relative_actions``, and the unpadded action vectors are fed into a
    ``RunningQuantileStats`` accumulator. The accumulator's result replaces
    the action entry of the copied stats.

    Args:
        dataset: Sized dataset yielding action/state samples.
        exclude_joints: Joint names passed to ``RelativeActionsProcessorStep``
            as ``exclude_joints``; ``None`` is treated as an empty list.
        action_names: Per-dimension action names passed to the same step.

    Returns:
        The copied stats dict with its action entry replaced.

    Raises:
        ValueError: If the dataset has no length, is empty, yields action and
            state batches of incompatible sizes, or provides fewer than two
            unpadded action vectors.
    """
    from lerobot.datasets.compute_stats import RunningQuantileStats
    from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep, to_relative_actions

    try:
        sample_count = len(dataset)
    except TypeError as exc:
        raise ValueError(
            "Cannot compute relative action statistics for a dataset without a finite length. "
            "Disable streaming or provide precomputed relative action statistics."
        ) from exc
    if sample_count == 0:
        raise ValueError("Cannot compute relative action statistics for an empty dataset.")

    meta = getattr(dataset, "meta", None)
    stats = deepcopy(getattr(meta, "stats", {}) or {})
    accumulator = RunningQuantileStats()
    # The processor step is only used here to build the relative/absolute mask.
    mask_source = RelativeActionsProcessorStep(
        enabled=True,
        exclude_joints=list(exclude_joints or []),
        action_names=action_names,
    )
    seen_vectors = 0

    for raw_action, raw_state, pad_mask in _iter_action_state_training_samples(dataset):
        action = _to_float_tensor(raw_action, key=ACTION)
        state = _to_float_tensor(raw_state, key=OBS_STATE)
        states = _state_reference_batch(state)
        actions = _action_training_batch(action, states)

        if actions.shape[0] != states.shape[0]:
            if states.shape[0] != 1:
                raise ValueError(
                    "Cannot compute relative action statistics: action and state batch sizes differ "
                    f"({actions.shape[0]} vs {states.shape[0]})."
                )
            # A single reference state is shared by every action row.
            states = states.expand(actions.shape[0], -1)

        relative = to_relative_actions(
            actions,
            states,
            mask_source._build_mask(actions.shape[-1]),
        )
        vectors = _unpadded_relative_action_vectors(relative, pad_mask)
        if vectors.numel() == 0:
            # Nothing left after removing padding; skip this sample.
            continue

        row_count = int(vectors.reshape(-1, vectors.shape[-1]).shape[0])
        accumulator.update(vectors.numpy())
        seen_vectors += row_count

    if seen_vectors < 2:
        raise ValueError(
            "Cannot compute relative action statistics from fewer than 2 unpadded action vectors."
        )

    stats[ACTION] = accumulator.get_statistics()
    return stats
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
"""
|
||||
@@ -292,10 +457,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
processor_stats = dataset.meta.stats
|
||||
if not cfg.is_reward_model_training and getattr(active_cfg, "use_relative_actions", False):
|
||||
if is_main_process:
|
||||
logging.info("Computing relative-action output statistics for processor normalization")
|
||||
processor_stats = _make_relative_action_training_stats(
|
||||
dataset,
|
||||
exclude_joints=getattr(active_cfg, "relative_exclude_joints", []),
|
||||
action_names=_resolve_action_feature_names(active_cfg, dataset),
|
||||
)
|
||||
|
||||
processor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
processor_kwargs["dataset_stats"] = processor_stats
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
@@ -304,7 +478,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"stats": processor_stats,
|
||||
"features": {**policy.config.input_features, **policy.config.output_features},
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
@@ -312,7 +486,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
}
|
||||
postprocessor_overrides = {
|
||||
"unnormalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"stats": processor_stats,
|
||||
"features": policy.config.output_features,
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
@@ -321,7 +495,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor_overrides["relative_actions_processor"] = {
|
||||
"enabled": True,
|
||||
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
|
||||
"action_names": getattr(active_cfg, "action_feature_names", None),
|
||||
"action_names": _resolve_action_feature_names(active_cfg, dataset),
|
||||
}
|
||||
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
|
||||
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
|
||||
@@ -23,6 +23,7 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch import nn
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
@@ -49,6 +50,7 @@ from lerobot.processor import (
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
)
|
||||
from lerobot.scripts.lerobot_train import _make_relative_action_training_stats
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
@@ -1754,6 +1756,106 @@ def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stat
|
||||
|
||||
|
||||
|
||||
def test_groot_n1_7_relative_action_training_processors_save_relative_action_stats(tmp_path):
    """Saved GROOT processors must carry relative-action stats, not the absolute dataset stats.

    Relative action stats are computed from a two-sample in-memory dataset,
    the pre/post processors are built from them and saved, and the saved
    action min/max in both the pack and the unpack step state files are
    checked against the expected relative values.
    """
    input_features, output_features = _groot_features(state_dim=6, action_dim=6)
    action_names = [
        "shoulder_pan.pos",
        "shoulder_lift.pos",
        "elbow_flex.pos",
        "wrist_flex.pos",
        "wrist_roll.pos",
        "gripper.pos",
    ]
    config = GrootConfig(
        input_features=input_features,
        output_features=output_features,
        device="cpu",
        use_bf16=False,
        action_decode_transform=None,
        use_relative_actions=True,
        relative_exclude_joints=["gripper"],
        action_feature_names=action_names,
    )
    # Absolute stats exposed by the fake dataset; their action range is much
    # wider than the expected relative range asserted at the end of the test.
    absolute_dataset_stats = {
        OBS_STATE: {
            "min": torch.tensor([-50.0, -60.0, -70.0, -80.0, -90.0, 0.0]),
            "max": torch.tensor([50.0, 60.0, 70.0, 80.0, 90.0, 100.0]),
        },
        ACTION: {
            "min": torch.tensor([-100.0, -110.0, -120.0, -130.0, -140.0, 0.0]),
            "max": torch.tensor([100.0, 110.0, 120.0, 130.0, 140.0, 100.0]),
        },
    }
    # Two samples, each with one state vector and a two-step action chunk.
    samples = [
        {
            OBS_STATE: torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 0.0]),
            ACTION: torch.tensor(
                [
                    [8.0, 17.0, 26.0, 35.0, 44.0, 0.0],
                    [12.0, 23.0, 34.0, 45.0, 56.0, 100.0],
                ]
            ),
        },
        {
            OBS_STATE: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 50.0]),
            ACTION: torch.tensor(
                [
                    [-1.0, -2.0, -3.0, -4.0, -5.0, 25.0],
                    [1.0, 2.0, 3.0, 4.0, 5.0, 75.0],
                ]
            ),
        },
    ]

    # Minimal stand-in dataset: only ``meta``, ``__len__`` and ``__getitem__``.
    class _RelativeStatsDataset:
        meta = SimpleNamespace(
            stats=absolute_dataset_stats,
            features={ACTION: {"names": action_names}},
        )

        def __len__(self):
            return len(samples)

        def __getitem__(self, idx):
            return samples[idx]

    relative_dataset_stats = _make_relative_action_training_stats(
        _RelativeStatsDataset(),
        exclude_joints=["gripper"],
        action_names=action_names,
    )
    # These values equal (action - state) on the first five channels of the
    # samples above, with the excluded gripper channel left absolute.
    expected_relative_action_stats = {
        "min": torch.tensor([-2.0, -3.0, -4.0, -5.0, -6.0, 0.0]),
        "max": torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0, 100.0]),
    }

    preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=relative_dataset_stats)
    preprocessor.save_pretrained(tmp_path)
    postprocessor.save_pretrained(tmp_path)

    # The saved preprocessor must list the relative-actions step, and its pack
    # step state must hold the relative action range.
    preprocessor_config = json.loads((tmp_path / "policy_preprocessor.json").read_text())
    assert any(step.get("registry_name") == "relative_actions_processor" for step in preprocessor_config["steps"])
    pack_entry = next(
        step
        for step in preprocessor_config["steps"]
        if step.get("registry_name") == "groot_n1_7_pack_inputs_v1"
    )
    pack_state = load_file(tmp_path / pack_entry["state_file"])
    torch.testing.assert_close(pack_state[f"{ACTION}.min"], expected_relative_action_stats["min"])
    torch.testing.assert_close(pack_state[f"{ACTION}.max"], expected_relative_action_stats["max"])

    # The saved postprocessor must list the absolute-actions step, and its
    # unpack/unnormalize step state must hold the same relative action range.
    postprocessor_config = json.loads((tmp_path / "policy_postprocessor.json").read_text())
    assert any(step.get("registry_name") == "absolute_actions_processor" for step in postprocessor_config["steps"])
    unpack_entry = next(
        step
        for step in postprocessor_config["steps"]
        if step.get("registry_name", "").startswith("groot_action_unpack_unnormalize")
    )
    unpack_state = load_file(tmp_path / unpack_entry["state_file"])
    torch.testing.assert_close(unpack_state[f"{ACTION}.min"], expected_relative_action_stats["min"])
    torch.testing.assert_close(unpack_state[f"{ACTION}.max"], expected_relative_action_stats["max"])
|
||||
|
||||
|
||||
def test_groot_policy_selects_n1_7_model_class(monkeypatch):
|
||||
from lerobot.policies.groot.groot_n1_7 import GR00TN17
|
||||
|
||||
|
||||
Reference in New Issue
Block a user