mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-04 00:27:15 +00:00
chore(policies): add guards, warnings and comments + recover tests n1.5 check
This commit is contained in:
@@ -324,9 +324,14 @@ class GrootConfig(PreTrainedConfig):
|
||||
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
|
||||
use_flash_attention: bool = False
|
||||
|
||||
# Enable GR00T-style state-relative action chunks. Prefer deriving action representation from
|
||||
# embodiment metadata; relative_exclude_joints is a flat-vector override for datasets without it.
|
||||
# Enable GR00T-style state-relative action chunks (action chunk expressed relative to the current
|
||||
# observation state).
|
||||
use_relative_actions: bool = False
|
||||
|
||||
# relative_exclude_joints names the action dimensions that stay absolute; the
|
||||
# match is substring/case-insensitive against the dataset action feature names. With the empty
|
||||
# default every dimension is treated as relative, including the gripper -- set e.g. ["gripper"] to
|
||||
# keep the gripper absolute, matching the Isaac-GR00T single-arm + absolute-gripper convention.
|
||||
relative_exclude_joints: list[str] = field(default_factory=list)
|
||||
|
||||
# Training parameters
|
||||
|
||||
@@ -996,6 +996,7 @@ def _build_n1_7_relative_action_processor_assets(
|
||||
}
|
||||
for group in groups
|
||||
]
|
||||
# 40 matches the action horizon of the only N1.7 base model (nvidia/GR00T-N1.7-3B)
|
||||
action_horizon = min(config.chunk_size, 40)
|
||||
modality_config: dict[str, Any] = {
|
||||
"state": {"modality_keys": [group.key for group in groups]},
|
||||
@@ -1194,6 +1195,13 @@ def make_groot_pre_post_processors(
|
||||
)
|
||||
relative_step: RelativeActionsProcessorStep | None = None
|
||||
if config.use_relative_actions and not uses_native_relative_actions:
|
||||
logging.warning(
|
||||
"GR00T relative actions are using the generic RelativeActionsProcessorStep fallback because "
|
||||
"the checkpoint already carries non-relative statistics. Relative deltas will be normalized "
|
||||
"with absolute action stats rather than Isaac-GR00T's per-horizon relative stats. For "
|
||||
"OSS-faithful relative normalization, build from a checkpoint without baked-in stats (or "
|
||||
"pass dataset_meta) so native relative stats are computed."
|
||||
)
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=list(config.relative_exclude_joints or []),
|
||||
@@ -1658,6 +1666,25 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
return None
|
||||
return torch.cat(normalized_groups, dim=-1)
|
||||
|
||||
def _uses_relative_action_groups(self) -> bool:
|
||||
"""True when the action modality declares at least one relative group.
|
||||
|
||||
Relative groups normalize with per-chunk-timestep (2D) ``relative_action`` stats, which the
|
||||
flat ``_min_max_norm`` fallback cannot honor, so a relative config that fails grouped
|
||||
normalization must fail loudly rather than silently mis-scale every timestep.
|
||||
"""
|
||||
if not isinstance(self.modality_config, dict):
|
||||
return False
|
||||
action_config = self.modality_config.get("action", {})
|
||||
if not isinstance(action_config, dict):
|
||||
return False
|
||||
action_configs = action_config.get("action_configs", [])
|
||||
if not isinstance(action_configs, list):
|
||||
return False
|
||||
return any(
|
||||
isinstance(cfg, dict) and config_value(cfg.get("rep")) == "relative" for cfg in action_configs
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
@@ -1775,6 +1802,15 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
normalized_action = self._normalize_action_groups_for_training(action)
|
||||
if normalized_action is not None:
|
||||
action = normalized_action
|
||||
elif self._uses_relative_action_groups():
|
||||
raise ValueError(
|
||||
"GrootN17PackInputsStep could not apply native grouped normalization to a "
|
||||
"relative-action chunk: the action layout or horizon does not match the "
|
||||
f"checkpoint relative_action stats (action shape {tuple(action.shape)}). The flat "
|
||||
"min/max fallback cannot honor per-chunk-timestep relative stats, so refusing to "
|
||||
"silently mis-normalize. Recompute the relative action stats so their horizon and "
|
||||
"dimensions match the action chunk."
|
||||
)
|
||||
else:
|
||||
flat = _min_max_norm(action.reshape(bsz * horizon, dim), ACTION)
|
||||
action = flat.view(bsz, horizon, dim)
|
||||
|
||||
@@ -30,10 +30,12 @@ from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.policies.factory import make_policy_config, make_pre_post_processors
|
||||
from lerobot.policies.groot.configuration_groot import (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
GROOT_N1_7,
|
||||
GROOT_N1_7_BASE_MODEL,
|
||||
GrootConfig,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
from lerobot.policies.groot.processor_groot import (
|
||||
@@ -350,6 +352,18 @@ def test_groot_defaults_use_n1_7():
|
||||
assert len(config.action_delta_indices) == 40
|
||||
|
||||
|
||||
@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"])
|
||||
def test_groot_normalize_model_version_rejects_n1_5_aliases(legacy_version):
|
||||
# model_version is no longer a GrootConfig field, but normalize_groot_model_version is still
|
||||
# live (e.g. via infer_groot_model_version) and must keep rejecting N1.5 with removal guidance.
|
||||
with pytest.raises(ValueError, match="Unsupported GR00T model_version"):
|
||||
normalize_groot_model_version(legacy_version)
|
||||
|
||||
|
||||
def test_groot_normalize_model_version_accepts_n1_7():
|
||||
assert normalize_groot_model_version(GROOT_N1_7) == GROOT_N1_7
|
||||
|
||||
|
||||
def test_groot_n1_7_accepts_named_action_decode_transform():
|
||||
config = GrootConfig(
|
||||
action_decode_transform="libero",
|
||||
@@ -997,6 +1011,42 @@ def test_groot_n1_7_pack_inputs_normalizes_action_chunk_per_dimension_before_pad
|
||||
assert action_mask[0, :, 3:].sum().item() == 0
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_raises_when_relative_groups_cannot_normalize():
|
||||
# Relative groups carry per-chunk-timestep stats; if the action horizon exceeds the available
|
||||
# stat rows, grouped normalization cannot apply and the flat fallback would silently mis-scale.
|
||||
step = GrootN17PackInputsStep(
|
||||
action_horizon=3,
|
||||
valid_action_horizon=3,
|
||||
max_state_dim=2,
|
||||
max_action_dim=2,
|
||||
normalize_min_max=True,
|
||||
raw_stats={
|
||||
"state": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}},
|
||||
"action": {"single_arm": {"min": [0.0, 0.0], "max": [1.0, 1.0]}},
|
||||
# only one horizon row, but the action chunk has horizon 3
|
||||
"relative_action": {"single_arm": {"min": [[-1.0, -1.0]], "max": [[1.0, 1.0]]}},
|
||||
},
|
||||
modality_config={
|
||||
"state": {"modality_keys": ["single_arm"]},
|
||||
"action": {
|
||||
"modality_keys": ["single_arm"],
|
||||
"action_configs": [
|
||||
{"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None}
|
||||
],
|
||||
"delta_indices": [0, 1, 2],
|
||||
},
|
||||
},
|
||||
)
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {OBS_STATE: torch.zeros(1, 2)},
|
||||
TransitionKey.ACTION: torch.zeros(1, 3, 2),
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="could not apply native grouped normalization"):
|
||||
step(transition)
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_trains_native_relative_groups_with_absolute_gripper():
|
||||
step = GrootN17PackInputsStep(
|
||||
action_horizon=2,
|
||||
|
||||
Reference in New Issue
Block a user