relative experiment

This commit is contained in:
Steven Palma
2026-06-15 16:38:36 +02:00
parent 13ed657056
commit 05a9ca274b
3 changed files with 304 additions and 4 deletions
@@ -389,6 +389,40 @@ class GrootConfig(PreTrainedConfig):
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
# Inference-only override for the number of flow-matching denoising steps used to decode an
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
# inference speed for action quality; applied at base-model load via _create_groot_model.
num_inference_timesteps: int | None = None
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
# Acts as an upper bound only: _resolve_action_queue_steps takes the minimum of this value
# (clamped to at least 1) and the horizons inferred from the checkpoint/embodiment.
execution_horizon: int | None = None
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
# base-model build (so it applies during lerobot-train, which uses __init__ not
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
warm_start_embodiment_slot: int | str | None = None
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
# When True, GR00T converts absolute->relative inside its own pack step (training) and
# reconstructs absolute inside its own flat decode step (inference), using a cached
# reference state. The dataset stays absolute; compute relative ACTION stats with
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
# "['gripper']"` (this only rewrites stats, not actions).
use_relative_actions: bool = False
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
# Case-insensitive token match against action_feature_names.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
# identified and kept absolute. When None, the gripper cannot be identified.
action_feature_names: list[str] | None = None
# Fine-tuning control arguments
# Whether to fine-tune the llm backbone
+116 -1
View File
@@ -54,6 +54,98 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy")
def _resolve_embodiment_id(value: int | str) -> int:
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
Raises ValueError listing the known keys if the name is unknown.
"""
from .processor_groot import N1_7_EMBODIMENT_MAPPING
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
if isinstance(value, int):
return value
if value in N1_7_EMBODIMENT_MAPPING:
return N1_7_EMBODIMENT_MAPPING[value]
raise ValueError(
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
)
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
    """Copy category-specific action-head weights from one embodiment slot to another.

    Warm-starts a cold target embodiment slot (e.g. 'new_embodiment') from a pretrained
    slot. The layers visited are ``layer1``/``layer2`` of ``model.action_head.state_encoder``,
    ``W1``/``W2``/``W3`` of ``action_encoder`` and ``layer1``/``layer2`` of ``action_decoder``.
    Every visited layer that exposes per-category ``W`` and ``b`` tensors has index
    ``source_id`` copied into index ``target_id`` (first dimension), in place.

    Nothing is copied, and a warning is logged, when the two ids are identical or the model
    has no ``action_head``. Missing submodules or layers are skipped silently. A layer whose
    ``W.shape[0]`` does not cover both ids is skipped with a warning while the remaining
    layers are still processed. A final info (or warning, if nothing was copied) message
    summarises the result.

    Args:
        model: Object exposing an ``action_head`` attribute with the submodules above.
        source_id: Index of the embodiment slot to copy from.
        target_id: Index of the embodiment slot to overwrite.
    """
    if source_id == target_id:
        logger.warning(
            "GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
            "skipping (nothing to copy).",
            source_id,
        )
        return

    action_head = getattr(model, "action_head", None)
    if action_head is None:
        logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
        return

    # Each entry is (submodule, [CategorySpecificLinear attribute names]).
    linear_groups = [
        (getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
        (getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
        (getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
    ]

    copied: list[str] = []
    # no_grad allows in-place writes into leaf parameters without recording autograd history,
    # so there is no need to reach through ``.data``.
    with torch.no_grad():
        for submodule, attr_names in linear_groups:
            if submodule is None:
                continue
            submodule_name = type(submodule).__name__
            for attr_name in attr_names:
                lin = getattr(submodule, attr_name, None)
                if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
                    continue
                num_categories = lin.W.shape[0]
                if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
                    logger.warning(
                        "GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
                        "for %s.%s (num_categories=%d); skipping this layer.",
                        source_id,
                        target_id,
                        submodule_name,
                        attr_name,
                        num_categories,
                    )
                    continue
                # source_id != target_id here, so the source and destination views never overlap.
                lin.W[target_id].copy_(lin.W[source_id])
                lin.b[target_id].copy_(lin.b[source_id])
                copied.append(f"{submodule_name}.{attr_name}")

    if copied:
        logger.info(
            "GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
            "to slot %d for: %s.",
            source_id,
            target_id,
            ", ".join(copied),
        )
    else:
        logger.warning(
            "GR00T warm_start_embodiment_slot: no action-head weights were copied "
            "(source_id=%d, target_id=%d).",
            source_id,
            target_id,
        )
class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration."""
@@ -93,6 +185,25 @@ class GrootPolicy(PreTrainedPolicy):
transformers_loading_kwargs={"trust_remote_code": True},
)
# Inference-only override for the number of flow-matching denoising steps. The action
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
# t schedule adapt automatically.
if self.config.num_inference_timesteps is not None:
n = int(self.config.num_inference_timesteps)
model.config.num_inference_timesteps = n
model.action_head.num_inference_timesteps = n
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
if self.config.warm_start_embodiment_slot is not None:
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
_warm_start_embodiment_slot(model, source_id, target_id)
return model
def reset(self):
@@ -260,7 +371,11 @@ class GrootPolicy(PreTrainedPolicy):
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
return min(horizons)
# An explicit config value only caps the open-loop horizon (inference cadence): it joins the
# candidates reduced by min() below, so it can lower but never raise the inferred value.
if self.config.execution_horizon is not None:
horizons.append(max(1, int(self.config.execution_horizon)))
return max(1, min(horizons))
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
+154 -3
View File
@@ -47,6 +47,8 @@ from lerobot.processor import (
RenameObservationsProcessorStep,
batch_to_transition,
policy_action_to_transition,
to_absolute_actions,
to_relative_actions,
transition_to_batch,
transition_to_policy_action,
)
@@ -117,6 +119,39 @@ class _GrootN17CheckpointProcessorAssets:
use_albumentations: bool
def _resolve_base_model_local_dir(base_model_path: str | None) -> str | None:
"""Resolve a base model path to a local snapshot dir holding its sidecar JSONs.
``is_raw_groot_n1_7_checkpoint`` needs a local directory (or config.json) to inspect, so a
bare HF repo-id (e.g. ``nvidia/GR00T-N1.7-3B``) would never be recognised as a raw N1.7
checkpoint and the processor would fall back to LeRobot default image geometry instead of the
checkpoint's processor_config.json geometry. When the path is not already a local dir, this
downloads just the JSON sidecars and returns the local snapshot dir. Offline-safe: any failure
returns the original string unchanged. Only used on the fresh-build (training) path; inference
loads the serialized processor, so no per-inference network call is added.
"""
if base_model_path is None:
return None
if Path(base_model_path).expanduser().is_dir():
return base_model_path
try:
from huggingface_hub import snapshot_download
local_dir = snapshot_download(
base_model_path,
repo_type="model",
allow_patterns=["*.json"],
)
logging.debug(
"Resolved GR00T base model '%s' to local snapshot '%s' for processor asset loading.",
base_model_path,
local_dir,
)
return local_dir
except Exception: # noqa: BLE001 (offline-safe: fall back to the original path on any failure)
return base_model_path
def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None:
"""Load N1.7 processor settings from checkpoint sidecar JSON files.
@@ -124,10 +159,11 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
can keep using caller-provided dataset stats and config values.
"""
if not is_raw_groot_n1_7_checkpoint(config.base_model_path):
resolved_base_model_path = _resolve_base_model_local_dir(config.base_model_path)
if not is_raw_groot_n1_7_checkpoint(resolved_base_model_path):
return None
checkpoint_path = Path(config.base_model_path).expanduser()
checkpoint_path = Path(resolved_base_model_path).expanduser()
processor_config = _read_json(checkpoint_path / "processor_config.json")
processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict):
@@ -452,6 +488,40 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
return any(bool(modality_stats) for modality_stats in stats.values())
def _build_relative_action_mask(
action_dim: int,
exclude_joints: list[str] | None,
action_names: list[str] | None,
) -> list[bool]:
"""Build the per-dim relative-action mask (True = convert to relative, False = keep absolute).
Replicates ``RelativeActionsProcessorStep._build_mask`` semantics: dims are excluded
(kept absolute) by case-insensitive token match against ``action_names``.
When ``action_names`` is None we cannot identify the gripper, so this returns all-True
(every dim treated as relative). The user should ensure ``config.action_feature_names`` is
populated (the factory does this from dataset meta) so the gripper can be kept absolute;
arm-relative still works either way, but a missing-name gripper would be treated as relative.
"""
if not exclude_joints or action_names is None:
return [True] * action_dim
exclude_tokens = [str(name).lower() for name in exclude_joints if name]
if not exclude_tokens:
return [True] * action_dim
mask: list[bool] = []
for name in action_names[:action_dim]:
action_name = str(name).lower()
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < action_dim:
mask.extend([True] * (action_dim - len(mask)))
return mask
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
@@ -653,8 +723,15 @@ def _reconnect_groot_n1_7_pack_decode_steps(
if pack_step is None:
return
# Both decode steps read the pack step's cached state via a non-serialized ``pack_step`` link:
# GrootN17ActionDecodeStep reads the per-modality raw state; the relative-action path
# (GrootActionUnpackUnnormalizeStep) reads the cached reference state. Restore both links after
# deserialization.
for step in postprocessor.steps:
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
if (
isinstance(step, (GrootN17ActionDecodeStep, GrootActionUnpackUnnormalizeStep))
and step.pack_step is None
):
step.pack_step = pack_step
@@ -732,6 +809,9 @@ def make_groot_pre_post_processors(
video_modality_keys=video_modality_keys,
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
use_relative_actions=config.use_relative_actions,
relative_exclude_joints=config.relative_exclude_joints,
action_feature_names=config.action_feature_names,
)
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
@@ -791,6 +871,10 @@ def make_groot_pre_post_processors(
stats=padded_stats,
normalize_min_max=True,
clip_normalized_action=True,
use_relative_actions=config.use_relative_actions,
relative_exclude_joints=config.relative_exclude_joints,
action_feature_names=config.action_feature_names,
pack_step=pack_step,
)
else:
action_decode_step = GrootN17ActionDecodeStep(
@@ -1087,7 +1171,14 @@ class GrootN17PackInputsStep(ProcessorStep):
video_modality_keys: list[str] | None = None
raw_stats: dict[str, Any] | None = None
modality_config: dict[str, Any] | None = None
# Opt-in relative-action support: convert absolute->relative actions inside this pack step
# (training) using the cached raw reference state, keeping excluded joints (e.g. gripper)
# absolute. The paired GrootActionUnpackUnnormalizeStep reconstructs absolute on decode.
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
_last_reference_state: torch.Tensor | None = field(default=None, init=False, repr=False)
_warned_image_keys: bool = field(default=False, init=False, repr=False)
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
@@ -1229,6 +1320,9 @@ class GrootN17PackInputsStep(ProcessorStep):
formalize_language=self.formalize_language,
)
# Reference state for relative-action conversion (RAW, pre-normalization, (B, D)). Cached
# regardless of whether an action is present so inference caches it too for decode.
relative_reference_state: torch.Tensor | None = None
if OBS_STATE in obs:
state = obs[OBS_STATE]
if state.dim() != 2:
@@ -1237,6 +1331,9 @@ class GrootN17PackInputsStep(ProcessorStep):
if dim > self.max_state_dim:
raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.")
_cache_raw_state(state)
if self.use_relative_actions:
relative_reference_state = state.detach().clone()
self._last_reference_state = relative_reference_state
if self.normalize_min_max:
state = _min_max_norm(state, OBS_STATE)
state = state.unsqueeze(1)
@@ -1259,6 +1356,19 @@ class GrootN17PackInputsStep(ProcessorStep):
raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.")
if dim > self.max_action_dim:
raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.")
# Convert absolute->relative BEFORE normalization. The mask keeps excluded joints (e.g.
# gripper) absolute; to_relative_actions broadcasts the (B, D) reference state over T.
if self.use_relative_actions:
if relative_reference_state is None:
raise RuntimeError(
"GrootN17PackInputsStep.use_relative_actions requires observation.state "
"(OBS_STATE) to be present alongside the action to build the relative "
"reference, but no state was found in this transition."
)
mask = _build_relative_action_mask(
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
)
action = to_relative_actions(action, relative_reference_state, mask)
if self.normalize_min_max:
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
action = flat.view(bsz, horizon, dim)
@@ -1322,6 +1432,9 @@ class GrootN17PackInputsStep(ProcessorStep):
"video_modality_keys": self.video_modality_keys,
"raw_stats": self.raw_stats,
"modality_config": self.modality_config,
"use_relative_actions": self.use_relative_actions,
"relative_exclude_joints": self.relative_exclude_joints,
"action_feature_names": self.action_feature_names,
}
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
@@ -1329,6 +1442,11 @@ class GrootN17PackInputsStep(ProcessorStep):
return self._last_raw_state
def get_cached_reference_state(self) -> torch.Tensor | None:
"""Return the latest RAW (pre-normalization) (B, D) state used for relative-action conversion."""
return self._last_reference_state
def state_dict(self) -> dict[str, torch.Tensor]:
if not self.stats:
return {}
@@ -1803,6 +1921,13 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
clip_normalized_action: bool = False
libero_gripper_action: bool = False
libero_gripper_binarize: bool = True
# Opt-in relative-action reconstruction (paired with GrootN17PackInputsStep). After the
# min-max inverse, relative deltas (arm) + absolute gripper are converted back to absolute
# using the reference state cached by the linked pack_step (re-linked on reload).
use_relative_actions: bool = False
relative_exclude_joints: list[str] = field(default_factory=list)
action_feature_names: list[str] | None = None
pack_step: "GrootN17PackInputsStep | None" = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
@@ -1842,6 +1967,29 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
inv = (action + 1.0) * 0.5 * safe_denom + min_v
action = torch.where(mask, inv, min_v)
# Reconstruct absolute actions from relative deltas (arm) + absolute gripper, using the
# reference state cached by the linked pack step. The link is restored on reload by
# _reconnect_groot_n1_7_pack_decode_steps.
if self.use_relative_actions:
if self.pack_step is None:
raise RuntimeError(
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires a linked "
"GrootN17PackInputsStep to read the cached reference state, but pack_step is None. "
"Build both pipelines through make_groot_pre_post_processors (or load them together "
"via make_groot_pre_post_processors_from_pretrained)."
)
ref = self.pack_step.get_cached_reference_state()
if ref is None:
raise RuntimeError(
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires the reference state "
"cached by its connected GrootN17PackInputsStep to convert relative actions back to "
"absolute. Run the preprocessor on an observation before decoding actions."
)
relative_mask = _build_relative_action_mask(
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
)
action = to_absolute_actions(action, ref, relative_mask)
if self.libero_gripper_action and action.shape[-1] >= 7:
gripper = action[..., -1]
if self.libero_gripper_binarize:
@@ -1869,6 +2017,9 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
"clip_normalized_action": self.clip_normalized_action,
"libero_gripper_action": self.libero_gripper_action,
"libero_gripper_binarize": self.libero_gripper_binarize,
"use_relative_actions": self.use_relative_actions,
"relative_exclude_joints": self.relative_exclude_joints,
"action_feature_names": self.action_feature_names,
}
def state_dict(self) -> dict[str, torch.Tensor]: