From 55da4bf8aa6def40d8959905c8fb06007b11667c Mon Sep 17 00:00:00 2001 From: Andy Wrenn Date: Sat, 20 Jun 2026 06:30:50 -0700 Subject: [PATCH] Address GROOT relative action review feedback --- src/lerobot/policies/factory.py | 2 + .../policies/groot/configuration_groot.py | 12 ++--- src/lerobot/policies/groot/processor_groot.py | 18 ++++++- src/lerobot/scripts/lerobot_train.py | 12 ++--- tests/policies/groot/test_groot_n1_7.py | 52 ++++++------------- 5 files changed, 43 insertions(+), 53 deletions(-) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 8c2d3e070..2d8de4000 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -287,6 +287,7 @@ def make_pre_post_processors( config=policy_cfg, pretrained_path=pretrained_path, dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), preprocessor_overrides=kwargs.get("preprocessor_overrides"), postprocessor_overrides=kwargs.get("postprocessor_overrides"), preprocessor_config_filename=kwargs.get( @@ -399,6 +400,7 @@ def make_pre_post_processors( processors = make_groot_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), + dataset_meta=kwargs.get("dataset_meta"), ) elif isinstance(policy_cfg, XVLAConfig): diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index 004fa4be7..cd803f31a 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -268,7 +268,6 @@ class GrootConfig(PreTrainedConfig): ) # Groot-specific model parameters - model_version: str = GROOT_N1_7 # Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and # checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the @@ -325,11 +324,10 @@ class GrootConfig(PreTrainedConfig): # Set to True only after installing a flash-attn build matching your torch/CUDA env. use_flash_attention: bool = False - # Train on state-relative action chunks. The listed joints stay absolute, which is normally used - # for gripper channels whose command frame is not the arm joint state. + # Enable GR00T-style state-relative action chunks. Prefer deriving action representation from + # embodiment metadata; relative_exclude_joints is a flat-vector override for datasets without it. use_relative_actions: bool = False relative_exclude_joints: list[str] = field(default_factory=list) - action_feature_names: list[str] | None = None # Training parameters optimizer_lr: float = 1e-4 @@ -365,8 +363,6 @@ class GrootConfig(PreTrainedConfig): resume: bool = False def __post_init__(self): - self.model_version = normalize_groot_model_version(self.model_version) - if self.tokenizer_assets_repo is not None: raise ValueError( "Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks " @@ -417,9 +413,9 @@ class GrootConfig(PreTrainedConfig): setattr(self, field_name, n1_7_value) inferred_version = infer_groot_model_version(self.base_model_path) - if inferred_version is not None and inferred_version != self.model_version: + if inferred_version is not None and inferred_version != GROOT_N1_7: message = ( - f"GR00T model_version '{self.model_version}' does not match base_model_path " + f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path " f"'{self.base_model_path}', which looks like '{inferred_version}'." ) if inferred_version == GROOT_N1_5: diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 9fa7575f1..f91ebcdda 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -435,6 +435,7 @@ def make_groot_pre_post_processors_from_pretrained( pretrained_path: str, *, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + dataset_meta: Any | None = None, preprocessor_overrides: dict[str, Any] | None = None, postprocessor_overrides: dict[str, Any] | None = None, preprocessor_config_filename: str = f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json", @@ -457,6 +458,7 @@ def make_groot_pre_post_processors_from_pretrained( preprocessor, postprocessor = make_groot_pre_post_processors( config=processor_cfg, dataset_stats=dataset_stats, + dataset_meta=dataset_meta, ) # Raw checkpoints have no serialized pipelines to load overrides into, # so apply the caller overrides (e.g. device and rename_map from @@ -546,8 +548,20 @@ def _reconnect_groot_n1_7_pack_decode_steps( step.pack_step = pack_step +def _resolve_action_feature_names_from_dataset_meta(dataset_meta: Any | None) -> list[str] | None: + features = getattr(dataset_meta, "features", {}) or {} + action_feature = features.get(ACTION) if isinstance(features, dict) else None + if isinstance(action_feature, dict): + names = action_feature.get("names") + else: + names = getattr(action_feature, "names", None) + return list(names) if names is not None else None + + def make_groot_pre_post_processors( - config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None + config: GrootConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, + dataset_meta: Any | None = None, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], @@ -660,7 +674,7 @@ def make_groot_pre_post_processors( relative_step = RelativeActionsProcessorStep( enabled=True, exclude_joints=list(config.relative_exclude_joints or []), - action_names=config.action_feature_names, + action_names=_resolve_action_feature_names_from_dataset_meta(dataset_meta), ) input_steps.insert(2, relative_step) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index d99564fd4..eed3f1178 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -246,11 +246,7 @@ def _iter_action_state_training_samples(dataset: Any): yield item.get(ACTION), item.get(OBS_STATE), item.get(f"{ACTION}_is_pad") -def _resolve_action_feature_names(active_cfg: Any, dataset: Any) -> list[str] | None: - config_names = getattr(active_cfg, "action_feature_names", None) - if config_names is not None: - return list(config_names) - +def _resolve_action_feature_names(dataset: Any) -> list[str] | None: features = getattr(getattr(dataset, "meta", None), "features", {}) or {} action_feature = features.get(ACTION) if isinstance(features, dict) else None if isinstance(action_feature, dict): @@ -464,14 +460,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): processor_stats = _make_relative_action_training_stats( dataset, exclude_joints=getattr(active_cfg, "relative_exclude_joints", []), - action_names=_resolve_action_feature_names(active_cfg, dataset), + action_names=_resolve_action_feature_names(dataset), ) processor_kwargs = {} if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path: processor_kwargs["dataset_stats"] = processor_stats - if cfg.is_reward_model_training: + if cfg.is_reward_model_training or getattr(active_cfg, "use_relative_actions", False): processor_kwargs["dataset_meta"] = dataset.meta if not cfg.is_reward_model_training and processor_pretrained_path is not None: @@ -495,7 +491,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): preprocessor_overrides["relative_actions_processor"] = { "enabled": True, "exclude_joints": getattr(active_cfg, "relative_exclude_joints", []), - "action_names": _resolve_action_feature_names(active_cfg, dataset), + "action_names": _resolve_action_feature_names(dataset), } postprocessor_overrides["absolute_actions_processor"] = {"enabled": True} processor_kwargs["preprocessor_overrides"] = preprocessor_overrides diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 5eba6f075..b1ef56495 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -67,11 +67,10 @@ def _groot_features( ) -def _groot_config(model_version: str = GROOT_N1_7) -> GrootConfig: +def _groot_config() -> GrootConfig: input_features, output_features = _groot_features(state_dim=8, action_dim=7) kwargs = {"action_decode_transform": GROOT_ACTION_DECODE_TRANSFORM_LIBERO} return GrootConfig( - model_version=model_version, input_features=input_features, output_features=output_features, device="cpu", @@ -83,7 +82,6 @@ def _groot_config(model_version: str = GROOT_N1_7) -> GrootConfig: def _raw_n1_7_libero_config(model_path) -> GrootConfig: input_features, output_features = _groot_features(state_dim=8, action_dim=7) return GrootConfig( - model_version=GROOT_N1_7, base_model_path=str(model_path), embodiment_tag="libero_sim", input_features=input_features, @@ -351,7 +349,6 @@ class _DummyGrootModel(nn.Module): def test_groot_defaults_use_n1_7(): config = GrootConfig(device="cpu") - assert config.model_version == GROOT_N1_7 assert config.base_model_path == GROOT_N1_7_BASE_MODEL assert config.max_state_dim == 132 assert config.max_action_dim == 132 @@ -362,7 +359,6 @@ def test_groot_defaults_use_n1_7(): def test_groot_n1_7_accepts_named_action_decode_transform(): config = GrootConfig( - model_version=GROOT_N1_7, action_decode_transform="libero", device="cpu", ) @@ -374,23 +370,15 @@ def test_groot_n1_7_accepts_named_action_decode_transform(): def test_groot_n1_7_rejects_legacy_libero_gripper_action_decode_transform(legacy_transform): with pytest.raises(ValueError, match="Unsupported GR00T N1.7 action decode transform"): GrootConfig( - model_version=GROOT_N1_7, - action_decode_transform=legacy_transform, + action_decode_transform=legacy_transform, device="cpu", ) -@pytest.mark.parametrize("legacy_version", ["n1.5", "n1_5", "n15", "1.5"]) -def test_groot_rejects_n1_5_aliases(legacy_version): - with pytest.raises(ValueError, match="Unsupported GR00T model_version"): - GrootConfig(model_version=legacy_version, device="cpu") - - def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7(): with pytest.raises(ValueError, match="does not match base_model_path"): GrootConfig( - model_version=GROOT_N1_7, - base_model_path="nvidia/GR00T-N1.5-3B", + base_model_path="nvidia/GR00T-N1.5-3B", device="cpu", ) @@ -398,10 +386,9 @@ def test_groot_config_rejects_mismatched_n1_5_path_for_n1_7(): def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_gr00t(): sys.modules.pop("gr00t", None) - config = make_policy_config("groot", model_version=GROOT_N1_7, device="cpu") + config = make_policy_config("groot", device="cpu") assert isinstance(config, GrootConfig) - assert config.model_version == GROOT_N1_7 assert "gr00t" not in sys.modules @@ -417,7 +404,7 @@ def test_groot_predict_action_chunk_forwards_n1_7_rtc_prefix(monkeypatch): dummy_model = _DummyGrootModel() monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model)) - config = _groot_config(GROOT_N1_7) + config = _groot_config() policy = GrootPolicy(config) policy.config.rtc_config = SimpleNamespace(execution_horizon=6) @@ -446,7 +433,7 @@ def test_groot_predict_action_chunk_strips_padded_n1_7_rtc_prefix(monkeypatch): dummy_model = _DummyGrootModel() monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model)) - config = _groot_config(GROOT_N1_7) + config = _groot_config() policy = GrootPolicy(config) policy.config.rtc_config = SimpleNamespace(execution_horizon=6) @@ -490,7 +477,6 @@ def test_groot_n1_7_predict_action_chunk_truncates_to_checkpoint_valid_horizon(t monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel())) input_features, output_features = _groot_features(state_dim=8, action_dim=7) config = GrootConfig( - model_version=GROOT_N1_7, base_model_path=str(model_path), embodiment_tag="libero_sim", input_features=input_features, @@ -518,8 +504,7 @@ def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path): # so construction itself raises before from_pretrained is reached. with pytest.raises(ValueError, match="does not match base_model_path"): config = GrootConfig( - model_version=GROOT_N1_7, - base_model_path="nvidia/GR00T-N1.5-3B", + base_model_path="nvidia/GR00T-N1.5-3B", input_features=input_features, output_features=output_features, device="cpu", @@ -534,13 +519,12 @@ def test_groot_from_pretrained_keeps_matching_caller_config(tmp_path, monkeypatc model_path = tmp_path / "GR00T-N1.7-local" model_path.mkdir() - config = _groot_config(GROOT_N1_7) + config = _groot_config() monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: _DummyGrootModel())) policy = GrootPolicy.from_pretrained(model_path, config=config) - assert policy.config.model_version == GROOT_N1_7 assert policy.config.base_model_path == str(model_path) @@ -555,7 +539,6 @@ def test_groot_from_pretrained_infers_n1_7_from_ambiguous_local_config(tmp_path, policy = GrootPolicy.from_pretrained(model_path) - assert policy.config.model_version == GROOT_N1_7 assert policy.config.base_model_path == str(model_path) @@ -1340,8 +1323,7 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config( # so construction itself raises before from_pretrained is reached. with pytest.raises(ValueError, match="does not match base_model_path"): config = GrootConfig( - model_version=GROOT_N1_7, - base_model_path="nvidia/GR00T-N1.5-3B", + base_model_path="nvidia/GR00T-N1.5-3B", input_features=input_features, output_features=output_features, device="cpu", @@ -1353,7 +1335,7 @@ def test_groot_from_pretrained_rejects_caller_config_mismatch_from_local_config( def test_groot_n1_7_processors_are_registered_lazily_without_external_gr00t(): sys.modules.pop("gr00t", None) - config = _groot_config(GROOT_N1_7) + config = _groot_config() preprocessor, _ = make_groot_pre_post_processors(config) step_types = {type(step) for step in preprocessor.steps} @@ -1692,7 +1674,7 @@ def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch): def test_groot_n1_7_saved_processors_reload_through_factory(tmp_path): - config = _groot_config(GROOT_N1_7) + config = _groot_config() dataset_stats = { OBS_STATE: { "min": torch.zeros(8), @@ -1724,7 +1706,7 @@ def test_groot_n1_7_saved_processors_reload_through_factory(tmp_path): def test_groot_n1_7_saved_processors_reload_through_factory_preserves_saved_stats(tmp_path): - config = _groot_config(GROOT_N1_7) + config = _groot_config() saved_stats = { OBS_STATE: { "min": torch.full((8,), -2.0), @@ -1774,7 +1756,6 @@ def test_groot_n1_7_relative_action_training_processors_save_relative_action_sta action_decode_transform=None, use_relative_actions=True, relative_exclude_joints=["gripper"], - action_feature_names=action_names, ) absolute_dataset_stats = { OBS_STATE: { @@ -1829,7 +1810,9 @@ def test_groot_n1_7_relative_action_training_processors_save_relative_action_sta "max": torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0, 100.0]), } - preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=relative_dataset_stats) + preprocessor, postprocessor = make_groot_pre_post_processors( + config, dataset_stats=relative_dataset_stats, dataset_meta=_RelativeStatsDataset.meta + ) preprocessor.save_pretrained(tmp_path) postprocessor.save_pretrained(tmp_path) @@ -1867,7 +1850,7 @@ def test_groot_policy_selects_n1_7_model_class(monkeypatch): monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(fake_from_pretrained)) - policy = GrootPolicy(_groot_config(GROOT_N1_7)) + policy = GrootPolicy(_groot_config()) assert called["pretrained_model_name_or_path"] == GROOT_N1_7_BASE_MODEL assert isinstance(policy._groot_model, _DummyGrootModel) @@ -1878,7 +1861,7 @@ def test_groot_policy_forwards_n1_7_qwen_inputs(monkeypatch): dummy_model = _DummyGrootModel() monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: dummy_model)) - policy = GrootPolicy(_groot_config(GROOT_N1_7)) + policy = GrootPolicy(_groot_config()) batch = { "state": torch.zeros(2, 1, 132), @@ -1941,7 +1924,6 @@ def test_groot_n1_7_select_action_uses_checkpoint_valid_horizon(tmp_path, monkey monkeypatch.setattr(GR00TN17, "from_pretrained", classmethod(lambda cls, **kwargs: HorizonModel())) input_features, output_features = _groot_features(state_dim=8, action_dim=7) config = GrootConfig( - model_version=GROOT_N1_7, base_model_path=str(model_path), embodiment_tag="libero_sim", input_features=input_features,