diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py
index 9ff6c11e6..2324ef9be 100644
--- a/src/lerobot/policies/groot/configuration_groot.py
+++ b/src/lerobot/policies/groot/configuration_groot.py
@@ -389,6 +389,40 @@ class GrootConfig(PreTrainedConfig):
     # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
     embodiment_tag: str = "new_embodiment"
 
+    # Inference-only override for the number of flow-matching denoising steps used to decode an
+    # action chunk. None = use the model checkpoint default (currently 4). Higher values trade
+    # inference speed for action quality; applied (clamped to >= 1) in _create_groot_model.
+    num_inference_timesteps: int | None = None
+
+    # If set, caps the number of open-loop actions executed before replanning (inference cadence).
+    # It can only shorten, never extend, the horizon inferred in _resolve_action_queue_steps.
+    execution_horizon: int | None = None
+
+    # Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
+    # embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
+    # 'new_embodiment' slot. Accepts an embodiment name (e.g.
+    # 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
+    # base-model build (so it applies during lerobot-train, which uses __init__ not
+    # from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
+    warm_start_embodiment_slot: int | str | None = None
+
+    # Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
+    # When True, GR00T converts absolute->relative inside its own pack step (training) and
+    # reconstructs absolute inside its own flat decode step (inference), using a cached
+    # reference state. The dataset stays absolute; compute relative ACTION stats with
+    # `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
+    # "['gripper']"` (this only rewrites stats, not actions).
+    use_relative_actions: bool = False
+
+    # Joint names kept absolute (not converted to relative) when use_relative_actions is True.
+    # Case-insensitive substring match against action_feature_names.
+    relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
+
+    # Action dimension names from dataset metadata; auto-populated by the policy factory from
+    # dataset meta. Used to build the relative-action mask so the gripper can be identified and
+    # kept absolute. When None, every dim (gripper included) is converted to relative.
+    action_feature_names: list[str] | None = None
+
     # Fine-tuning control arguments
 
     # Whether to fine-tune the llm backbone
diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py
index 8ecc33a88..c7bae4561 100644
--- a/src/lerobot/policies/groot/modeling_groot.py
+++ b/src/lerobot/policies/groot/modeling_groot.py
@@ -54,6 +54,98 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T", bound="GrootPolicy")
 
 
+def _resolve_embodiment_id(value: int | str) -> int:
+    """Resolve an embodiment id from an int or an N1.7 embodiment name.
+
+    Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
+    Raises ValueError listing the known keys if the name is unknown.
+    """
+    from .processor_groot import N1_7_EMBODIMENT_MAPPING
+
+    if isinstance(value, bool):  # bool is a subclass of int; reject it explicitly.
+        raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
+    if isinstance(value, int):
+        return value
+    if value in N1_7_EMBODIMENT_MAPPING:
+        return N1_7_EMBODIMENT_MAPPING[value]
+    raise ValueError(
+        f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
+        f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
+    )
+
+
+def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
+    """Copy category-specific action-head weights from one embodiment slot to another.
+
+    Used at every fresh base-model build to warm-start a cold target embodiment slot
+    (e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
+    parameters across every CategorySpecificLinear in the action head's state encoder,
+    action encoder, and action decoder. Identical ids are a no-op; a layer where either id
+    is out of range is skipped. Both cases log a warning.
+    """
+    if source_id == target_id:
+        logger.warning(
+            "GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
+            "skipping (nothing to copy).",
+            source_id,
+        )
+        return
+
+    action_head = getattr(model, "action_head", None)
+    if action_head is None:
+        logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
+        return
+
+    # Each entry is (submodule, [CategorySpecificLinear attribute names]).
+    linear_groups = [
+        (getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
+        (getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
+        (getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
+    ]
+
+    copied: list[str] = []
+    with torch.no_grad():
+        for submodule, attr_names in linear_groups:
+            if submodule is None:
+                continue
+            submodule_name = type(submodule).__name__
+            for attr_name in attr_names:
+                lin = getattr(submodule, attr_name, None)
+                if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
+                    continue
+                num_categories = lin.W.shape[0]
+                if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
+                    logger.warning(
+                        "GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
+                        "for %s.%s (num_categories=%d); skipping this layer.",
+                        source_id,
+                        target_id,
+                        submodule_name,
+                        attr_name,
+                        num_categories,
+                    )
+                    continue
+                lin.W.data[target_id] = lin.W.data[source_id].clone()
+                lin.b.data[target_id] = lin.b.data[source_id].clone()
+                copied.append(f"{submodule_name}.{attr_name}")
+
+    if copied:
+        logger.info(
+            "GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
+            "to slot %d for: %s.",
+            source_id,
+            target_id,
+            ", ".join(copied),
+        )
+    else:
+        logger.warning(
+            "GR00T warm_start_embodiment_slot: no action-head weights were copied "
+            "(source_id=%d, target_id=%d).",
+            source_id,
+            target_id,
+        )
+
+
 class GrootPolicy(PreTrainedPolicy):
     """Wrapper around external Groot model for LeRobot integration."""
 
@@ -93,6 +185,25 @@ class GrootPolicy(PreTrainedPolicy):
             transformers_loading_kwargs={"trust_remote_code": True},
         )
 
+        # Inference-only override for the number of flow-matching denoising steps. The action
+        # head reads self.num_inference_timesteps in get_action_with_features and derives
+        # dt = 1/n from it, so clamp to >= 1 (as execution_horizon is) to avoid a zero step count.
+        if self.config.num_inference_timesteps is not None:
+            n = max(1, int(self.config.num_inference_timesteps))
+            model.config.num_inference_timesteps = n
+            model.action_head.num_inference_timesteps = n
+
+        # Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
+        # slot's action-head weights. Done here (not in from_pretrained) so it applies on every
+        # fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
+        # directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
+        # through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
+        # immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
+        if self.config.warm_start_embodiment_slot is not None:
+            source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
+            target_id = _resolve_embodiment_id(self.config.embodiment_tag)
+            _warm_start_embodiment_slot(model, source_id, target_id)
+
         return model
 
     def reset(self):
@@ -260,7 +371,11 @@ class GrootPolicy(PreTrainedPolicy):
             horizons.append(checkpoint_action_horizon)
         if execution_horizon is not None:
             horizons.append(execution_horizon)
-        return min(horizons)
+        # An explicit config override caps the open-loop horizon (inference cadence): it joins the
+        # min() below, so it can shorten but never extend the checkpoint/embodiment-inferred value.
+        if self.config.execution_horizon is not None:
+            horizons.append(max(1, int(self.config.execution_horizon)))
+        return max(1, min(horizons))
 
     def _resolve_prediction_horizon(self, actions: Tensor) -> int:
         """Return the policy-facing action horizon for a native GR00T prediction."""
diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py
index da1ef16ac..a01fb95bc 100644
--- a/src/lerobot/policies/groot/processor_groot.py
+++ b/src/lerobot/policies/groot/processor_groot.py
@@ -47,6 +47,8 @@ from lerobot.processor import (
     RenameObservationsProcessorStep,
     batch_to_transition,
     policy_action_to_transition,
+    to_absolute_actions,
+    to_relative_actions,
     transition_to_batch,
     transition_to_policy_action,
 )
@@ -117,6 +119,39 @@ class _GrootN17CheckpointProcessorAssets:
     use_albumentations: bool
 
 
+def _resolve_base_model_local_dir(base_model_path: str | None) -> str | None:
+    """Resolve a base model path to a local snapshot dir holding its sidecar JSONs.
+
+    ``is_raw_groot_n1_7_checkpoint`` needs a local directory (or config.json) to inspect, so a
+    bare HF repo-id (e.g. ``nvidia/GR00T-N1.7-3B``) would never be recognised as a raw N1.7
+    checkpoint and the processor would fall back to LeRobot default image geometry instead of the
+    checkpoint's processor_config.json geometry. When the path is not already a local dir, this
+    downloads just the JSON sidecars and returns the local snapshot dir. Offline-safe: any failure
+    returns the original string unchanged. Only used on the fresh-build (training) path; inference
+    loads the serialized processor, so no per-inference network call is added.
+    """
+    if base_model_path is None:
+        return None
+    if Path(base_model_path).expanduser().is_dir():
+        return base_model_path
+    try:
+        from huggingface_hub import snapshot_download
+
+        local_dir = snapshot_download(
+            base_model_path,
+            repo_type="model",
+            allow_patterns=["*.json"],
+        )
+        logging.debug(
+            "Resolved GR00T base model '%s' to local snapshot '%s' for processor asset loading.",
+            base_model_path,
+            local_dir,
+        )
+        return local_dir
+    except Exception:  # noqa: BLE001 (offline-safe: fall back to the original path on any failure)
+        return base_model_path
+
+
 def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None:
     """Load N1.7 processor settings from checkpoint sidecar JSON files.
 
@@ -124,10 +159,11 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
     can keep using caller-provided dataset stats and config values.
     """
 
-    if not is_raw_groot_n1_7_checkpoint(config.base_model_path):
+    resolved_base_model_path = _resolve_base_model_local_dir(config.base_model_path)
+    if not is_raw_groot_n1_7_checkpoint(resolved_base_model_path):
         return None
 
-    checkpoint_path = Path(config.base_model_path).expanduser()
+    checkpoint_path = Path(resolved_base_model_path).expanduser()
     processor_config = _read_json(checkpoint_path / "processor_config.json")
     processor_kwargs = processor_config.get("processor_kwargs", {})
     if not isinstance(processor_kwargs, dict):
@@ -452,6 +488,40 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
     return any(bool(modality_stats) for modality_stats in stats.values())
 
 
+def _build_relative_action_mask(
+    action_dim: int,
+    exclude_joints: list[str] | None,
+    action_names: list[str] | None,
+) -> list[bool]:
+    """Build the per-dim relative-action mask (True = convert to relative, False = keep absolute).
+
+    Replicates ``RelativeActionsProcessorStep._build_mask`` semantics: dims are excluded
+    (kept absolute) by case-insensitive substring match against ``action_names``.
+
+    When ``action_names`` is None we cannot identify the gripper, so this returns all-True
+    (every dim treated as relative). The user should ensure ``config.action_feature_names`` is
+    populated (the factory does this from dataset meta) so the gripper can be kept absolute;
+    arm-relative still works either way, but a missing-name gripper would be treated as relative.
+    """
+    if not exclude_joints or action_names is None:
+        return [True] * action_dim
+
+    exclude_tokens = [str(name).lower() for name in exclude_joints if name]
+    if not exclude_tokens:
+        return [True] * action_dim
+
+    mask: list[bool] = []
+    for name in action_names[:action_dim]:
+        action_name = str(name).lower()
+        is_excluded = any(token in action_name for token in exclude_tokens)
+        mask.append(not is_excluded)
+
+    if len(mask) < action_dim:
+        mask.extend([True] * (action_dim - len(mask)))
+
+    return mask
+
+
 # GR00T normalizes state/action inside its own processor steps and so deliberately has no
 # NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
 # IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
@@ -653,8 +723,15 @@ def _reconnect_groot_n1_7_pack_decode_steps(
     if pack_step is None:
         return
 
+    # Both decode steps read the pack step's cached state via a non-serialized ``pack_step`` link:
+    # GrootN17ActionDecodeStep reads the per-modality raw state; the relative-action path
+    # (GrootActionUnpackUnnormalizeStep) reads the cached reference state. Restore both links after
+    # deserialization.
     for step in postprocessor.steps:
-        if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
+        if (
+            isinstance(step, (GrootN17ActionDecodeStep, GrootActionUnpackUnnormalizeStep))
+            and step.pack_step is None
+        ):
             step.pack_step = pack_step
 
 
@@ -732,6 +809,9 @@ def make_groot_pre_post_processors(
         video_modality_keys=video_modality_keys,
         raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
         modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
+        use_relative_actions=config.use_relative_actions,
+        relative_exclude_joints=config.relative_exclude_joints,
+        action_feature_names=config.action_feature_names,
     )
 
     # Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
@@ -791,6 +871,10 @@ def make_groot_pre_post_processors(
             stats=padded_stats,
             normalize_min_max=True,
             clip_normalized_action=True,
+            use_relative_actions=config.use_relative_actions,
+            relative_exclude_joints=config.relative_exclude_joints,
+            action_feature_names=config.action_feature_names,
+            pack_step=pack_step,
         )
     else:
         action_decode_step = GrootN17ActionDecodeStep(
@@ -1087,7 +1171,14 @@ class GrootN17PackInputsStep(ProcessorStep):
     video_modality_keys: list[str] | None = None
     raw_stats: dict[str, Any] | None = None
     modality_config: dict[str, Any] | None = None
+    # Opt-in relative-action support: convert absolute->relative actions inside this pack step
+    # (training) using the cached raw reference state, keeping excluded joints (e.g. gripper)
+    # absolute. The paired GrootActionUnpackUnnormalizeStep reconstructs absolute on decode.
+    use_relative_actions: bool = False
+    relative_exclude_joints: list[str] = field(default_factory=list)
+    action_feature_names: list[str] | None = None
     _last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
+    _last_reference_state: torch.Tensor | None = field(default=None, init=False, repr=False)
     _warned_image_keys: bool = field(default=False, init=False, repr=False)
 
     def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
@@ -1229,6 +1320,9 @@ class GrootN17PackInputsStep(ProcessorStep):
             formalize_language=self.formalize_language,
         )
 
+        # Reference state for relative-action conversion (RAW, pre-normalization, (B, D)). Cached
+        # regardless of whether an action is present so inference caches it too for decode.
+        relative_reference_state: torch.Tensor | None = None
         if OBS_STATE in obs:
             state = obs[OBS_STATE]
             if state.dim() != 2:
@@ -1237,6 +1331,9 @@ class GrootN17PackInputsStep(ProcessorStep):
             if dim > self.max_state_dim:
                 raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.")
             _cache_raw_state(state)
+            if self.use_relative_actions:
+                relative_reference_state = state.detach().clone()
+                self._last_reference_state = relative_reference_state
             if self.normalize_min_max:
                 state = _min_max_norm(state, OBS_STATE)
             state = state.unsqueeze(1)
@@ -1259,6 +1356,19 @@ class GrootN17PackInputsStep(ProcessorStep):
                 raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.")
             if dim > self.max_action_dim:
                 raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.")
+            # Convert absolute->relative BEFORE normalization. The mask keeps excluded joints (e.g.
+            # gripper) absolute; to_relative_actions broadcasts the (B, D) reference state over T.
+            if self.use_relative_actions:
+                if relative_reference_state is None:
+                    raise RuntimeError(
+                        "GrootN17PackInputsStep.use_relative_actions requires observation.state "
+                        "(OBS_STATE) to be present alongside the action to build the relative "
+                        "reference, but no state was found in this transition."
+                    )
+                mask = _build_relative_action_mask(
+                    action.shape[-1], self.relative_exclude_joints, self.action_feature_names
+                )
+                action = to_relative_actions(action, relative_reference_state, mask)
             if self.normalize_min_max:
                 flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
                 action = flat.view(bsz, horizon, dim)
@@ -1322,6 +1432,9 @@ class GrootN17PackInputsStep(ProcessorStep):
             "video_modality_keys": self.video_modality_keys,
             "raw_stats": self.raw_stats,
             "modality_config": self.modality_config,
+            "use_relative_actions": self.use_relative_actions,
+            "relative_exclude_joints": self.relative_exclude_joints,
+            "action_feature_names": self.action_feature_names,
         }
 
     def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
@@ -1329,6 +1442,11 @@ class GrootN17PackInputsStep(ProcessorStep):
 
         return self._last_raw_state
 
+    def get_cached_reference_state(self) -> torch.Tensor | None:
+        """Return the latest RAW (pre-normalization) (B, D) state used for relative-action conversion."""
+
+        return self._last_reference_state
+
     def state_dict(self) -> dict[str, torch.Tensor]:
         if not self.stats:
             return {}
@@ -1803,6 +1921,13 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
     clip_normalized_action: bool = False
     libero_gripper_action: bool = False
     libero_gripper_binarize: bool = True
+    # Opt-in relative-action reconstruction (paired with GrootN17PackInputsStep). After the
+    # min-max inverse, relative deltas (arm) + absolute gripper are converted back to absolute
+    # using the reference state cached by the linked pack_step (re-linked on reload).
+    use_relative_actions: bool = False
+    relative_exclude_joints: list[str] = field(default_factory=list)
+    action_feature_names: list[str] | None = None
+    pack_step: "GrootN17PackInputsStep | None" = field(default=None, repr=False)
 
     def __call__(self, transition: EnvTransition) -> EnvTransition:
         # Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
@@ -1842,6 +1967,29 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
             inv = (action + 1.0) * 0.5 * safe_denom + min_v
             action = torch.where(mask, inv, min_v)
 
+        # Reconstruct absolute actions from relative deltas (arm) + absolute gripper, using the
+        # reference state cached by the linked pack step (moved to the action's device first; the
+        # link is restored on reload by _reconnect_groot_n1_7_pack_decode_steps).
+        if self.use_relative_actions:
+            if self.pack_step is None:
+                raise RuntimeError(
+                    "GrootActionUnpackUnnormalizeStep.use_relative_actions requires a linked "
+                    "GrootN17PackInputsStep to read the cached reference state, but pack_step is None. "
+                    "Build both pipelines through make_groot_pre_post_processors (or load them together "
+                    "via make_groot_pre_post_processors_from_pretrained)."
+                )
+            ref = self.pack_step.get_cached_reference_state()
+            if ref is None:
+                raise RuntimeError(
+                    "GrootActionUnpackUnnormalizeStep.use_relative_actions requires the reference state "
+                    "cached by its connected GrootN17PackInputsStep to convert relative actions back to "
+                    "absolute. Run the preprocessor on an observation before decoding actions."
+                )
+            relative_mask = _build_relative_action_mask(
+                action.shape[-1], self.relative_exclude_joints, self.action_feature_names
+            )
+            action = to_absolute_actions(action, ref.to(action.device), relative_mask)
+
         if self.libero_gripper_action and action.shape[-1] >= 7:
             gripper = action[..., -1]
             if self.libero_gripper_binarize:
@@ -1869,6 +2017,9 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
             "clip_normalized_action": self.clip_normalized_action,
             "libero_gripper_action": self.libero_gripper_action,
             "libero_gripper_binarize": self.libero_gripper_binarize,
+            "use_relative_actions": self.use_relative_actions,
+            "relative_exclude_joints": self.relative_exclude_joints,
+            "action_feature_names": self.action_feature_names,
         }
 
     def state_dict(self) -> dict[str, torch.Tensor]: