mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 93e29b0cfc |
@@ -42,12 +42,8 @@ GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
|
||||
# falls back to these when a checkpoint ships no image sizing in its processor_config
|
||||
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
|
||||
# are resized to the expected resolution instead of being patchified at full camera
|
||||
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
|
||||
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
|
||||
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
@@ -389,40 +385,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
|
||||
# Inference-only override for the number of flow-matching denoising steps used to decode an
|
||||
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
|
||||
# inference speed for action quality; applied at base-model load via _create_groot_model.
|
||||
num_inference_timesteps: int | None = None
|
||||
|
||||
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
|
||||
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
|
||||
execution_horizon: int | None = None
|
||||
|
||||
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
|
||||
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
|
||||
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
|
||||
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
|
||||
# base-model build (so it applies during lerobot-train, which uses __init__ not
|
||||
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
|
||||
warm_start_embodiment_slot: int | str | None = None
|
||||
|
||||
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
|
||||
# When True, GR00T converts absolute->relative inside its own pack step (training) and
|
||||
# reconstructs absolute inside its own flat decode step (inference), using a cached
|
||||
# reference state. The dataset stays absolute; compute relative ACTION stats with
|
||||
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
|
||||
# "['gripper']"` (this only rewrites stats, not actions).
|
||||
use_relative_actions: bool = False
|
||||
|
||||
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
|
||||
# Case-insensitive token match against action_feature_names.
|
||||
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
|
||||
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
|
||||
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
|
||||
# identified and kept absolute. When None, the gripper cannot be identified.
|
||||
action_feature_names: list[str] | None = None
|
||||
|
||||
# Fine-tuning control arguments
|
||||
|
||||
# Whether to fine-tune the llm backbone
|
||||
|
||||
@@ -54,98 +54,6 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
|
||||
def _resolve_embodiment_id(value: int | str) -> int:
|
||||
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
|
||||
|
||||
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
|
||||
Raises ValueError listing the known keys if the name is unknown.
|
||||
"""
|
||||
from .processor_groot import N1_7_EMBODIMENT_MAPPING
|
||||
|
||||
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
|
||||
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if value in N1_7_EMBODIMENT_MAPPING:
|
||||
return N1_7_EMBODIMENT_MAPPING[value]
|
||||
raise ValueError(
|
||||
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
|
||||
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
|
||||
)
|
||||
|
||||
|
||||
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
|
||||
"""Copy category-specific action-head weights from one embodiment slot to another.
|
||||
|
||||
Used at base-model load (training only) to warm-start a cold target embodiment slot
|
||||
(e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
|
||||
parameters across every CategorySpecificLinear in the action head's state encoder,
|
||||
action encoder, and action decoder. No-ops (with a logged warning) if the ids are out
|
||||
of range or identical.
|
||||
"""
|
||||
if source_id == target_id:
|
||||
logger.warning(
|
||||
"GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
|
||||
"skipping (nothing to copy).",
|
||||
source_id,
|
||||
)
|
||||
return
|
||||
|
||||
action_head = getattr(model, "action_head", None)
|
||||
if action_head is None:
|
||||
logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
|
||||
return
|
||||
|
||||
# Each entry is (submodule, [CategorySpecificLinear attribute names]).
|
||||
linear_groups = [
|
||||
(getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
|
||||
(getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
|
||||
(getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
|
||||
]
|
||||
|
||||
copied: list[str] = []
|
||||
with torch.no_grad():
|
||||
for submodule, attr_names in linear_groups:
|
||||
if submodule is None:
|
||||
continue
|
||||
submodule_name = type(submodule).__name__
|
||||
for attr_name in attr_names:
|
||||
lin = getattr(submodule, attr_name, None)
|
||||
if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
|
||||
continue
|
||||
num_categories = lin.W.shape[0]
|
||||
if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
|
||||
logger.warning(
|
||||
"GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
|
||||
"for %s.%s (num_categories=%d); skipping this layer.",
|
||||
source_id,
|
||||
target_id,
|
||||
submodule_name,
|
||||
attr_name,
|
||||
num_categories,
|
||||
)
|
||||
continue
|
||||
lin.W.data[target_id] = lin.W.data[source_id].clone()
|
||||
lin.b.data[target_id] = lin.b.data[source_id].clone()
|
||||
copied.append(f"{submodule_name}.{attr_name}")
|
||||
|
||||
if copied:
|
||||
logger.info(
|
||||
"GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
|
||||
"to slot %d for: %s.",
|
||||
source_id,
|
||||
target_id,
|
||||
", ".join(copied),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"GR00T warm_start_embodiment_slot: no action-head weights were copied "
|
||||
"(source_id=%d, target_id=%d).",
|
||||
source_id,
|
||||
target_id,
|
||||
)
|
||||
|
||||
|
||||
class GrootPolicy(PreTrainedPolicy):
|
||||
"""Wrapper around external Groot model for LeRobot integration."""
|
||||
|
||||
@@ -185,25 +93,6 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
# Inference-only override for the number of flow-matching denoising steps. The action
|
||||
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
|
||||
# t schedule adapt automatically.
|
||||
if self.config.num_inference_timesteps is not None:
|
||||
n = int(self.config.num_inference_timesteps)
|
||||
model.config.num_inference_timesteps = n
|
||||
model.action_head.num_inference_timesteps = n
|
||||
|
||||
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
|
||||
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
|
||||
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
|
||||
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
|
||||
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
|
||||
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
|
||||
if self.config.warm_start_embodiment_slot is not None:
|
||||
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
|
||||
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
|
||||
_warm_start_embodiment_slot(model, source_id, target_id)
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
@@ -371,11 +260,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
if execution_horizon is not None:
|
||||
horizons.append(execution_horizon)
|
||||
# An explicit config override caps the open-loop horizon (inference cadence), overriding
|
||||
# the value inferred from the checkpoint/embodiment.
|
||||
if self.config.execution_horizon is not None:
|
||||
horizons.append(max(1, int(self.config.execution_horizon)))
|
||||
return max(1, min(horizons))
|
||||
return min(horizons)
|
||||
|
||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||
@@ -543,16 +428,6 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
|
||||
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
|
||||
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
|
||||
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
|
||||
from .processor_groot import _GROOT_REF_HOLDER_KEY
|
||||
|
||||
holder = batch.get(_GROOT_REF_HOLDER_KEY)
|
||||
if holder is not None:
|
||||
holder.freeze()
|
||||
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
|
||||
@@ -47,8 +47,6 @@ from lerobot.processor import (
|
||||
RenameObservationsProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
to_absolute_actions,
|
||||
to_relative_actions,
|
||||
transition_to_batch,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
@@ -90,30 +88,6 @@ N1_7_EMBODIMENT_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
_GROOT_REF_HOLDER_KEY = "_groot_relative_ref_holder" # private; dropped by _filter_groot_inputs, never reaches the model
|
||||
|
||||
|
||||
class _GrootRelativeRefHolder:
|
||||
"""Runtime-only carrier shared (by object identity) between the pack step (owner/writer of the
|
||||
live reference), GrootPolicy.predict_action_chunk (freezes it at a real predict event), and the
|
||||
decode step (reads the frozen reference). Not serialized. One instance per pack step."""
|
||||
|
||||
__slots__ = ("reference_state", "raw_state", "frozen_reference", "frozen_raw")
|
||||
|
||||
def __init__(self):
|
||||
self.reference_state = None
|
||||
self.raw_state = None
|
||||
self.frozen_reference = None
|
||||
self.frozen_raw = None
|
||||
|
||||
def freeze(self) -> None:
|
||||
self.frozen_reference = self.reference_state
|
||||
self.frozen_raw = self.raw_state
|
||||
|
||||
def clear(self) -> None:
|
||||
self.reference_state = self.raw_state = self.frozen_reference = self.frozen_raw = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GrootN17CheckpointProcessorAssets:
|
||||
"""Processor metadata loaded from a raw Isaac-GR00T N1.7 checkpoint.
|
||||
@@ -143,39 +117,6 @@ class _GrootN17CheckpointProcessorAssets:
|
||||
use_albumentations: bool
|
||||
|
||||
|
||||
def _resolve_base_model_local_dir(base_model_path: str | None) -> str | None:
|
||||
"""Resolve a base model path to a local snapshot dir holding its sidecar JSONs.
|
||||
|
||||
``is_raw_groot_n1_7_checkpoint`` needs a local directory (or config.json) to inspect, so a
|
||||
bare HF repo-id (e.g. ``nvidia/GR00T-N1.7-3B``) would never be recognised as a raw N1.7
|
||||
checkpoint and the processor would fall back to LeRobot default image geometry instead of the
|
||||
checkpoint's processor_config.json geometry. When the path is not already a local dir, this
|
||||
downloads just the JSON sidecars and returns the local snapshot dir. Offline-safe: any failure
|
||||
returns the original string unchanged. Only used on the fresh-build (training) path; inference
|
||||
loads the serialized processor, so no per-inference network call is added.
|
||||
"""
|
||||
if base_model_path is None:
|
||||
return None
|
||||
if Path(base_model_path).expanduser().is_dir():
|
||||
return base_model_path
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = snapshot_download(
|
||||
base_model_path,
|
||||
repo_type="model",
|
||||
allow_patterns=["*.json"],
|
||||
)
|
||||
logging.debug(
|
||||
"Resolved GR00T base model '%s' to local snapshot '%s' for processor asset loading.",
|
||||
base_model_path,
|
||||
local_dir,
|
||||
)
|
||||
return local_dir
|
||||
except Exception: # noqa: BLE001 (offline-safe: fall back to the original path on any failure)
|
||||
return base_model_path
|
||||
|
||||
|
||||
def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17CheckpointProcessorAssets | None:
|
||||
"""Load N1.7 processor settings from checkpoint sidecar JSON files.
|
||||
|
||||
@@ -183,11 +124,10 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
can keep using caller-provided dataset stats and config values.
|
||||
"""
|
||||
|
||||
resolved_base_model_path = _resolve_base_model_local_dir(config.base_model_path)
|
||||
if not is_raw_groot_n1_7_checkpoint(resolved_base_model_path):
|
||||
if not is_raw_groot_n1_7_checkpoint(config.base_model_path):
|
||||
return None
|
||||
|
||||
checkpoint_path = Path(resolved_base_model_path).expanduser()
|
||||
checkpoint_path = Path(config.base_model_path).expanduser()
|
||||
processor_config = _read_json(checkpoint_path / "processor_config.json")
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
@@ -512,40 +452,6 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||
|
||||
|
||||
def _build_relative_action_mask(
|
||||
action_dim: int,
|
||||
exclude_joints: list[str] | None,
|
||||
action_names: list[str] | None,
|
||||
) -> list[bool]:
|
||||
"""Build the per-dim relative-action mask (True = convert to relative, False = keep absolute).
|
||||
|
||||
Replicates ``RelativeActionsProcessorStep._build_mask`` semantics: dims are excluded
|
||||
(kept absolute) by case-insensitive token match against ``action_names``.
|
||||
|
||||
When ``action_names`` is None we cannot identify the gripper, so this returns all-True
|
||||
(every dim treated as relative). The user should ensure ``config.action_feature_names`` is
|
||||
populated (the factory does this from dataset meta) so the gripper can be kept absolute;
|
||||
arm-relative still works either way, but a missing-name gripper would be treated as relative.
|
||||
"""
|
||||
if not exclude_joints or action_names is None:
|
||||
return [True] * action_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * action_dim
|
||||
|
||||
mask: list[bool] = []
|
||||
for name in action_names[:action_dim]:
|
||||
action_name = str(name).lower()
|
||||
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < action_dim:
|
||||
mask.extend([True] * (action_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
|
||||
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
|
||||
@@ -747,15 +653,8 @@ def _reconnect_groot_n1_7_pack_decode_steps(
|
||||
if pack_step is None:
|
||||
return
|
||||
|
||||
# Both decode steps read the pack step's cached state via a non-serialized ``pack_step`` link:
|
||||
# GrootN17ActionDecodeStep reads the per-modality raw state; the relative-action path
|
||||
# (GrootActionUnpackUnnormalizeStep) reads the cached reference state. Restore both links after
|
||||
# deserialization.
|
||||
for step in postprocessor.steps:
|
||||
if (
|
||||
isinstance(step, (GrootN17ActionDecodeStep, GrootActionUnpackUnnormalizeStep))
|
||||
and step.pack_step is None
|
||||
):
|
||||
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
|
||||
step.pack_step = pack_step
|
||||
|
||||
|
||||
@@ -833,9 +732,6 @@ def make_groot_pre_post_processors(
|
||||
video_modality_keys=video_modality_keys,
|
||||
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
|
||||
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
||||
use_relative_actions=config.use_relative_actions,
|
||||
relative_exclude_joints=config.relative_exclude_joints,
|
||||
action_feature_names=config.action_feature_names,
|
||||
)
|
||||
|
||||
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
|
||||
@@ -843,8 +739,7 @@ def make_groot_pre_post_processors(
|
||||
# N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no
|
||||
# processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new
|
||||
# embodiment, where checkpoint_assets is None) would patchify full-resolution camera
|
||||
# frames, inflating the VLM token count -- slowing both dataloading_s and update_s --
|
||||
# and feeding the model a resolution it was not trained on.
|
||||
# frames, inflating the VLM token count and feeding the model a resolution it was not trained on.
|
||||
if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None:
|
||||
image_target_size = checkpoint_assets.image_target_size
|
||||
image_crop_size = checkpoint_assets.image_crop_size
|
||||
@@ -868,9 +763,6 @@ def make_groot_pre_post_processors(
|
||||
shortest_image_edge=shortest_image_edge,
|
||||
crop_fraction=crop_fraction,
|
||||
use_albumentations=use_albumentations,
|
||||
# Run the image resize/normalize/patchify on the training device when
|
||||
# possible instead of the single CPU main-loop thread (the dominant
|
||||
# cost folded into dataloading_s).
|
||||
device=config.device,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
@@ -895,10 +787,6 @@ def make_groot_pre_post_processors(
|
||||
stats=padded_stats,
|
||||
normalize_min_max=True,
|
||||
clip_normalized_action=True,
|
||||
use_relative_actions=config.use_relative_actions,
|
||||
relative_exclude_joints=config.relative_exclude_joints,
|
||||
action_feature_names=config.action_feature_names,
|
||||
pack_step=pack_step,
|
||||
)
|
||||
else:
|
||||
action_decode_step = GrootN17ActionDecodeStep(
|
||||
@@ -1031,15 +919,22 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
||||
return proc
|
||||
|
||||
|
||||
def _transform_n1_7_image_for_vlm(
|
||||
def _transform_n1_7_image_for_vlm_albumentations(
|
||||
image: Image.Image,
|
||||
*,
|
||||
image_crop_size: list[int] | None,
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
use_albumentations: bool = False,
|
||||
) -> Image.Image:
|
||||
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
|
||||
|
||||
Used only for checkpoints saved with ``use_albumentations=True``. cv2 is
|
||||
CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations)
|
||||
geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The
|
||||
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
|
||||
torch path and must stay bit-exact to the upstream reference.
|
||||
"""
|
||||
if image_target_size is None:
|
||||
return image
|
||||
|
||||
@@ -1047,71 +942,46 @@ def _transform_n1_7_image_for_vlm(
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
if use_albumentations:
|
||||
try:
|
||||
import cv2
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
|
||||
) from exc
|
||||
try:
|
||||
import cv2
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"GR00T N1.7 checkpoints with use_albumentations=True require opencv-python-headless."
|
||||
) from exc
|
||||
|
||||
image_np = np.asarray(image)
|
||||
height, width = image_np.shape[:2]
|
||||
if height != width:
|
||||
square_edge = max(height, width)
|
||||
pad_h = square_edge - height
|
||||
pad_w = square_edge - width
|
||||
image_np = cv2.copyMakeBorder(
|
||||
image_np,
|
||||
pad_h // 2,
|
||||
pad_h - pad_h // 2,
|
||||
pad_w // 2,
|
||||
pad_w - pad_w // 2,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0),
|
||||
)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
if image_np.shape[:2] != (resize_edge, resize_edge):
|
||||
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
|
||||
|
||||
if crop_fraction is None and image_crop_size is not None:
|
||||
crop_fraction = image_crop_size[0] / float(target_h)
|
||||
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
||||
height, width = image_np.shape[:2]
|
||||
crop_h = max(1, int(height * crop_fraction))
|
||||
crop_w = max(1, int(width * crop_fraction))
|
||||
top = max(0, (height - crop_h) // 2)
|
||||
left = max(0, (width - crop_w) // 2)
|
||||
image_np = image_np[top : top + crop_h, left : left + crop_w]
|
||||
|
||||
if image_np.shape[:2] != (target_h, target_w):
|
||||
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
||||
return Image.fromarray(image_np)
|
||||
|
||||
square_edge = max(image.width, image.height)
|
||||
if image.width != image.height:
|
||||
padded = Image.new("RGB", (square_edge, square_edge))
|
||||
left = (square_edge - image.width) // 2
|
||||
top = (square_edge - image.height) // 2
|
||||
padded.paste(image, (left, top))
|
||||
image = padded
|
||||
image_np = np.asarray(image)
|
||||
height, width = image_np.shape[:2]
|
||||
if height != width:
|
||||
square_edge = max(height, width)
|
||||
pad_h = square_edge - height
|
||||
pad_w = square_edge - width
|
||||
image_np = cv2.copyMakeBorder(
|
||||
image_np,
|
||||
pad_h // 2,
|
||||
pad_h - pad_h // 2,
|
||||
pad_w // 2,
|
||||
pad_w - pad_w // 2,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0),
|
||||
)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
image = image.resize((resize_edge, resize_edge), Image.Resampling.BICUBIC)
|
||||
if image_np.shape[:2] != (resize_edge, resize_edge):
|
||||
image_np = cv2.resize(image_np, (resize_edge, resize_edge), interpolation=cv2.INTER_AREA)
|
||||
|
||||
if crop_fraction is None and image_crop_size is not None:
|
||||
crop_fraction = image_crop_size[0] / float(target_h)
|
||||
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
||||
crop_w = max(1, int(round(image.width * crop_fraction)))
|
||||
crop_h = max(1, int(round(image.height * crop_fraction)))
|
||||
left = max(0, (image.width - crop_w) // 2)
|
||||
top = max(0, (image.height - crop_h) // 2)
|
||||
image = image.crop((left, top, left + crop_w, top + crop_h))
|
||||
height, width = image_np.shape[:2]
|
||||
crop_h = max(1, int(height * crop_fraction))
|
||||
crop_w = max(1, int(width * crop_fraction))
|
||||
top = max(0, (height - crop_h) // 2)
|
||||
left = max(0, (width - crop_w) // 2)
|
||||
image_np = image_np[top : top + crop_h, left : left + crop_w]
|
||||
|
||||
if image.size != (target_w, target_h):
|
||||
image = image.resize((target_w, target_h), Image.Resampling.BICUBIC)
|
||||
return image
|
||||
if image_np.shape[:2] != (target_h, target_w):
|
||||
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
||||
return Image.fromarray(image_np)
|
||||
|
||||
|
||||
def _transform_n1_7_image_for_vlm_torch(
|
||||
@@ -1122,14 +992,15 @@ def _transform_n1_7_image_for_vlm_torch(
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
) -> torch.Tensor:
|
||||
"""Torch/torchvision port of the non-albumentations branch of
|
||||
:func:`_transform_n1_7_image_for_vlm`.
|
||||
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
|
||||
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
|
||||
|
||||
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
||||
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
||||
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
|
||||
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
|
||||
cv2/INTER_AREA path has no torch equivalent and stays on the PIL helper.
|
||||
cv2/INTER_AREA path has no torch equivalent and stays on
|
||||
:func:`_transform_n1_7_image_for_vlm_albumentations`.
|
||||
"""
|
||||
if image_target_size is None:
|
||||
return image
|
||||
@@ -1195,18 +1066,8 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
video_modality_keys: list[str] | None = None
|
||||
raw_stats: dict[str, Any] | None = None
|
||||
modality_config: dict[str, Any] | None = None
|
||||
# Opt-in relative-action support: convert absolute->relative actions inside this pack step
|
||||
# (training) using the cached raw reference state, keeping excluded joints (e.g. gripper)
|
||||
# absolute. The paired GrootActionUnpackUnnormalizeStep reconstructs absolute on decode.
|
||||
use_relative_actions: bool = False
|
||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
||||
action_feature_names: list[str] | None = None
|
||||
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
||||
_last_reference_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
||||
_ref_holder: "_GrootRelativeRefHolder" = field(
|
||||
default_factory=_GrootRelativeRefHolder, init=False, repr=False
|
||||
)
|
||||
|
||||
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
|
||||
available = {key for key in obs if key.startswith(OBS_IMAGES)}
|
||||
@@ -1328,7 +1189,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
start_idx += dim
|
||||
if grouped:
|
||||
self._last_raw_state = grouped
|
||||
self._ref_holder.raw_state = grouped
|
||||
|
||||
img_keys = self._ordered_image_keys(obs)
|
||||
if img_keys:
|
||||
@@ -1348,9 +1208,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
formalize_language=self.formalize_language,
|
||||
)
|
||||
|
||||
# Reference state for relative-action conversion (RAW, pre-normalization, (B, D)). Cached
|
||||
# regardless of whether an action is present so inference caches it too for decode.
|
||||
relative_reference_state: torch.Tensor | None = None
|
||||
if OBS_STATE in obs:
|
||||
state = obs[OBS_STATE]
|
||||
if state.dim() != 2:
|
||||
@@ -1359,10 +1216,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
if dim > self.max_state_dim:
|
||||
raise ValueError(f"State dimension {dim} exceeds max_state_dim {self.max_state_dim}.")
|
||||
_cache_raw_state(state)
|
||||
if self.use_relative_actions:
|
||||
relative_reference_state = state.detach().clone()
|
||||
self._last_reference_state = relative_reference_state
|
||||
self._ref_holder.reference_state = relative_reference_state
|
||||
if self.normalize_min_max:
|
||||
state = _min_max_norm(state, OBS_STATE)
|
||||
state = state.unsqueeze(1)
|
||||
@@ -1385,19 +1238,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
raise ValueError(f"Action horizon {horizon} exceeds action_horizon {self.action_horizon}.")
|
||||
if dim > self.max_action_dim:
|
||||
raise ValueError(f"Action dimension {dim} exceeds max_action_dim {self.max_action_dim}.")
|
||||
# Convert absolute->relative BEFORE normalization. The mask keeps excluded joints (e.g.
|
||||
# gripper) absolute; to_relative_actions broadcasts the (B, D) reference state over T.
|
||||
if self.use_relative_actions:
|
||||
if relative_reference_state is None:
|
||||
raise RuntimeError(
|
||||
"GrootN17PackInputsStep.use_relative_actions requires observation.state "
|
||||
"(OBS_STATE) to be present alongside the action to build the relative "
|
||||
"reference, but no state was found in this transition."
|
||||
)
|
||||
mask = _build_relative_action_mask(
|
||||
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
|
||||
)
|
||||
action = to_relative_actions(action, relative_reference_state, mask)
|
||||
if self.normalize_min_max:
|
||||
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
|
||||
action = flat.view(bsz, horizon, dim)
|
||||
@@ -1437,12 +1277,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
comp["action_mask"] = action_mask
|
||||
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.int32, device=device)
|
||||
|
||||
# Publish the runtime-only reference holder so the policy can freeze it at the predict
|
||||
# event and the decode step can read the frozen reference. It rides in COMPLEMENTARY_DATA,
|
||||
# survives the VLM-encode step and DeviceProcessorStep as a non-tensor, and reaches the
|
||||
# policy via the batch (by object identity) through the pipeline's shallow copies.
|
||||
comp[_GROOT_REF_HOLDER_KEY] = self._ref_holder
|
||||
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return transition
|
||||
@@ -1467,9 +1301,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
"video_modality_keys": self.video_modality_keys,
|
||||
"raw_stats": self.raw_stats,
|
||||
"modality_config": self.modality_config,
|
||||
"use_relative_actions": self.use_relative_actions,
|
||||
"relative_exclude_joints": self.relative_exclude_joints,
|
||||
"action_feature_names": self.action_feature_names,
|
||||
}
|
||||
|
||||
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
|
||||
@@ -1477,23 +1308,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
|
||||
return self._last_raw_state
|
||||
|
||||
def get_cached_reference_state(self) -> torch.Tensor | None:
|
||||
"""Return the latest RAW (pre-normalization) (B, D) state used for relative-action conversion."""
|
||||
|
||||
return self._last_reference_state
|
||||
|
||||
def get_reference_holder(self) -> "_GrootRelativeRefHolder":
|
||||
"""Return the runtime-only holder shared with the policy (writer) and decode step (reader)."""
|
||||
|
||||
return self._ref_holder
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear cached per-episode relative-action references (sync engine resets on episode boundaries)."""
|
||||
|
||||
self._last_reference_state = None
|
||||
self._last_raw_state = None
|
||||
self._ref_holder.clear()
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
if not self.stats:
|
||||
return {}
|
||||
@@ -1527,9 +1341,9 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
|
||||
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
|
||||
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
|
||||
CUDA device, the resize/rescale/normalize/patchify run there instead of on the
|
||||
single CPU main-loop thread. This keeps the output bit-identical on CPU and
|
||||
moves the dominant preprocessing cost off the critical path on GPU.
|
||||
CUDA device, the resize/rescale/normalize/patchify run there. This keeps the
|
||||
output bit-identical on CPU and moves the dominant preprocessing cost off
|
||||
the critical path on GPU.
|
||||
"""
|
||||
|
||||
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
@@ -1573,13 +1387,12 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
video_np = np.asarray(video)
|
||||
return [
|
||||
[
|
||||
_transform_n1_7_image_for_vlm(
|
||||
_transform_n1_7_image_for_vlm_albumentations(
|
||||
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
use_albumentations=True,
|
||||
)
|
||||
for timestep in range(video_np.shape[1])
|
||||
for view_idx in range(video_np.shape[2])
|
||||
@@ -1890,14 +1703,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
start_idx += dim
|
||||
|
||||
if self.use_relative_action:
|
||||
# Prefer the raw state frozen at the chunk-prediction event (see the relative-action
|
||||
# branch of GrootActionUnpackUnnormalizeStep). Falls back to the live cached raw state.
|
||||
holder = self.pack_step.get_reference_holder() if self.pack_step is not None else None
|
||||
raw_state = None
|
||||
if holder is not None:
|
||||
raw_state = holder.frozen_raw if holder.frozen_raw is not None else holder.raw_state
|
||||
if raw_state is None and self.pack_step is not None:
|
||||
raw_state = self.pack_step.get_cached_raw_state()
|
||||
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
|
||||
if raw_state is None:
|
||||
raise RuntimeError(
|
||||
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
|
||||
@@ -1975,13 +1781,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
clip_normalized_action: bool = False
|
||||
libero_gripper_action: bool = False
|
||||
libero_gripper_binarize: bool = True
|
||||
# Opt-in relative-action reconstruction (paired with GrootN17PackInputsStep). After the
|
||||
# min-max inverse, relative deltas (arm) + absolute gripper are converted back to absolute
|
||||
# using the reference state cached by the linked pack_step (re-linked on reload).
|
||||
use_relative_actions: bool = False
|
||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
||||
action_feature_names: list[str] | None = None
|
||||
pack_step: "GrootN17PackInputsStep | None" = field(default=None, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
|
||||
@@ -2021,35 +1820,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
inv = (action + 1.0) * 0.5 * safe_denom + min_v
|
||||
action = torch.where(mask, inv, min_v)
|
||||
|
||||
# Reconstruct absolute actions from relative deltas (arm) + absolute gripper, using the
|
||||
# reference state cached by the linked pack step. The link is restored on reload by
|
||||
# _reconnect_groot_n1_7_pack_decode_steps.
|
||||
if self.use_relative_actions:
|
||||
if self.pack_step is None:
|
||||
raise RuntimeError(
|
||||
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires a linked "
|
||||
"GrootN17PackInputsStep to read the cached reference state, but pack_step is None. "
|
||||
"Build both pipelines through make_groot_pre_post_processors (or load them together "
|
||||
"via make_groot_pre_post_processors_from_pretrained)."
|
||||
)
|
||||
# Prefer the reference frozen at the chunk-prediction event (set by
|
||||
# GrootPolicy.predict_action_chunk via the shared holder) so every popped delta of a
|
||||
# chunk reconstructs against that chunk's start state S_T, not the per-tick latest
|
||||
# state. Falls back to the live reference when nothing was frozen (e.g. decode without
|
||||
# a preceding predict event, or RTC/async where frozen == live).
|
||||
holder = self.pack_step.get_reference_holder()
|
||||
ref = holder.frozen_reference if holder.frozen_reference is not None else holder.reference_state
|
||||
if ref is None:
|
||||
raise RuntimeError(
|
||||
"GrootActionUnpackUnnormalizeStep.use_relative_actions requires the reference state "
|
||||
"cached by its connected GrootN17PackInputsStep to convert relative actions back to "
|
||||
"absolute. Run the preprocessor on an observation before decoding actions."
|
||||
)
|
||||
relative_mask = _build_relative_action_mask(
|
||||
action.shape[-1], self.relative_exclude_joints, self.action_feature_names
|
||||
)
|
||||
action = to_absolute_actions(action, ref, relative_mask)
|
||||
|
||||
if self.libero_gripper_action and action.shape[-1] >= 7:
|
||||
gripper = action[..., -1]
|
||||
if self.libero_gripper_binarize:
|
||||
@@ -2077,9 +1847,6 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
"clip_normalized_action": self.clip_normalized_action,
|
||||
"libero_gripper_action": self.libero_gripper_action,
|
||||
"libero_gripper_binarize": self.libero_gripper_binarize,
|
||||
"use_relative_actions": self.use_relative_actions,
|
||||
"relative_exclude_joints": self.relative_exclude_joints,
|
||||
"action_feature_names": self.action_feature_names,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
|
||||
@@ -41,7 +41,7 @@ from lerobot.policies.groot.processor_groot import (
|
||||
GrootN17ActionDecodeStep,
|
||||
GrootN17PackInputsStep,
|
||||
GrootN17VLMEncodeStep,
|
||||
_transform_n1_7_image_for_vlm,
|
||||
_transform_n1_7_image_for_vlm_albumentations,
|
||||
make_groot_pre_post_processors,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
@@ -1529,13 +1529,12 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
|
||||
|
||||
image_np = (np.arange(360 * 360 * 3, dtype=np.uint32) % 251).astype(np.uint8).reshape(360, 360, 3)
|
||||
|
||||
transformed = _transform_n1_7_image_for_vlm(
|
||||
transformed = _transform_n1_7_image_for_vlm_albumentations(
|
||||
Image.fromarray(image_np),
|
||||
image_crop_size=[230, 230],
|
||||
image_target_size=[256, 256],
|
||||
shortest_image_edge=256,
|
||||
crop_fraction=0.95,
|
||||
use_albumentations=True,
|
||||
)
|
||||
|
||||
expected = cv2.resize(image_np, (256, 256), interpolation=cv2.INTER_AREA)
|
||||
|
||||
Reference in New Issue
Block a user