From de1a9e5ad956937b29c6747aedf02898fbe710c2 Mon Sep 17 00:00:00 2001 From: Andrew Wrenn Date: Fri, 5 Jun 2026 09:31:04 -0700 Subject: [PATCH] Reconnect GR00T relative action processors --- src/lerobot/policies/groot/processor_groot.py | 19 ++++ tests/policies/groot/test_groot_n1_7.py | 93 ++++++++++++++++++- 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 4e98df259..f9143734c 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -33,12 +33,14 @@ else: ProcessorMixin = object from lerobot.processor import ( + AbsoluteActionsProcessorStep, AddBatchDimensionProcessorStep, DeviceProcessorStep, PolicyAction, PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, + RelativeActionsProcessorStep, RenameObservationsProcessorStep, batch_to_transition, policy_action_to_transition, @@ -542,9 +544,26 @@ def make_groot_pre_post_processors_from_pretrained( to_transition=policy_action_to_transition, to_output=transition_to_policy_action, ) + _reconnect_groot_relative_absolute_steps(preprocessor, postprocessor) return preprocessor, postprocessor +def _reconnect_groot_relative_absolute_steps( + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, +) -> None: + relative_step = next( + (step for step in preprocessor.steps if isinstance(step, RelativeActionsProcessorStep)), + None, + ) + if relative_step is None: + return + + for step in postprocessor.steps: + if isinstance(step, AbsoluteActionsProcessorStep) and step.relative_step is None: + step.relative_step = relative_step + + def make_groot_pre_post_processors( config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None ) -> tuple[ diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 918162e18..5f06bb73e 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -44,7 +44,11 @@ from lerobot.policies.groot.processor_groot import ( _transform_n1_7_image_for_vlm, make_groot_pre_post_processors, ) -from lerobot.processor import PolicyProcessorPipeline +from lerobot.processor import ( + AbsoluteActionsProcessorStep, + PolicyProcessorPipeline, + RelativeActionsProcessorStep, +) from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE @@ -701,6 +705,93 @@ def test_converted_raw_n1_7_processors_load_without_legacy_action_unpack_overrid assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps) +def test_converted_raw_n1_7_absolute_action_processors_load_without_relative_steps(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + save_dir = tmp_path / "absolute_pretrained_model" + + config.save_pretrained(save_dir) + preprocessor.save_pretrained(save_dir) + postprocessor.save_pretrained(save_dir) + + loaded_preprocessor, loaded_postprocessor = make_pre_post_processors( + config, + pretrained_path=str(save_dir), + preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}}, + ) + + assert any(isinstance(step, GrootN17PackInputsStep) for step in loaded_preprocessor.steps) + assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps) + assert not any(isinstance(step, RelativeActionsProcessorStep) for step in loaded_preprocessor.steps) + assert not any(isinstance(step, AbsoluteActionsProcessorStep) for step in loaded_postprocessor.steps) + + +def test_converted_raw_n1_7_relative_action_processors_reconnect_after_load(tmp_path): + model_path = tmp_path / "libero_spatial" + _write_raw_n1_7_libero_checkpoint(model_path) + config = _raw_n1_7_libero_config(model_path) + preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path)) + save_dir = tmp_path / "relative_pretrained_model" + action_names = [ + "shoulder_pan.pos", + "shoulder_lift.pos", + "elbow_flex.pos", + "wrist_flex.pos", + "wrist_roll.pos", + "gripper.pos", + ] + + config.save_pretrained(save_dir) + preprocessor.save_pretrained(save_dir) + postprocessor.save_pretrained(save_dir) + + preprocessor_config_path = save_dir / "policy_preprocessor.json" + preprocessor_config = json.loads(preprocessor_config_path.read_text()) + preprocessor_config["steps"].insert( + 2, + { + "registry_name": "relative_actions_processor", + "config": { + "enabled": True, + "exclude_joints": ["gripper"], + "action_names": action_names, + }, + }, + ) + preprocessor_config_path.write_text(json.dumps(preprocessor_config, indent=4) + "\n") + + postprocessor_config_path = save_dir / "policy_postprocessor.json" + postprocessor_config = json.loads(postprocessor_config_path.read_text()) + postprocessor_config["steps"].insert( + -1, + { + "registry_name": "absolute_actions_processor", + "config": {"enabled": True}, + }, + ) + postprocessor_config_path.write_text(json.dumps(postprocessor_config, indent=4) + "\n") + + loaded_preprocessor, loaded_postprocessor = make_pre_post_processors( + config, + pretrained_path=str(save_dir), + preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}}, + ) + + relative_step = next( + step for step in loaded_preprocessor.steps if isinstance(step, RelativeActionsProcessorStep) + ) + absolute_step = next( + step for step in loaded_postprocessor.steps if isinstance(step, AbsoluteActionsProcessorStep) + ) + + assert relative_step.enabled is True + assert relative_step.exclude_joints == ["gripper"] + assert relative_step.action_names == action_names + assert absolute_step.relative_step is relative_step + + def test_groot_n1_7_pack_inputs_rejects_state_dim_above_core_max(): step = GrootN17PackInputsStep( max_state_dim=2,