Reconnect GR00T relative action processors

This commit is contained in:
Andrew Wrenn
2026-06-05 09:31:04 -07:00
parent 6803439f22
commit de1a9e5ad9
2 changed files with 111 additions and 1 deletions
@@ -33,12 +33,14 @@ else:
ProcessorMixin = object
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
batch_to_transition,
policy_action_to_transition,
@@ -542,9 +544,26 @@ def make_groot_pre_post_processors_from_pretrained(
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
def _reconnect_groot_relative_absolute_steps(
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
) -> None:
relative_step = next(
(step for step in preprocessor.steps if isinstance(step, RelativeActionsProcessorStep)),
None,
)
if relative_step is None:
return
for step in postprocessor.steps:
if isinstance(step, AbsoluteActionsProcessorStep) and step.relative_step is None:
step.relative_step = relative_step
def make_groot_pre_post_processors(
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[
+92 -1
View File
@@ -44,7 +44,11 @@ from lerobot.policies.groot.processor_groot import (
_transform_n1_7_image_for_vlm,
make_groot_pre_post_processors,
)
from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor import (
AbsoluteActionsProcessorStep,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
@@ -701,6 +705,93 @@ def test_converted_raw_n1_7_processors_load_without_legacy_action_unpack_overrid
assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps)
def test_converted_raw_n1_7_absolute_action_processors_load_without_relative_steps(tmp_path):
model_path = tmp_path / "libero_spatial"
_write_raw_n1_7_libero_checkpoint(model_path)
config = _raw_n1_7_libero_config(model_path)
preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path))
save_dir = tmp_path / "absolute_pretrained_model"
config.save_pretrained(save_dir)
preprocessor.save_pretrained(save_dir)
postprocessor.save_pretrained(save_dir)
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
config,
pretrained_path=str(save_dir),
preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}},
)
assert any(isinstance(step, GrootN17PackInputsStep) for step in loaded_preprocessor.steps)
assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps)
assert not any(isinstance(step, RelativeActionsProcessorStep) for step in loaded_preprocessor.steps)
assert not any(isinstance(step, AbsoluteActionsProcessorStep) for step in loaded_postprocessor.steps)
def test_converted_raw_n1_7_relative_action_processors_reconnect_after_load(tmp_path):
model_path = tmp_path / "libero_spatial"
_write_raw_n1_7_libero_checkpoint(model_path)
config = _raw_n1_7_libero_config(model_path)
preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path))
save_dir = tmp_path / "relative_pretrained_model"
action_names = [
"shoulder_pan.pos",
"shoulder_lift.pos",
"elbow_flex.pos",
"wrist_flex.pos",
"wrist_roll.pos",
"gripper.pos",
]
config.save_pretrained(save_dir)
preprocessor.save_pretrained(save_dir)
postprocessor.save_pretrained(save_dir)
preprocessor_config_path = save_dir / "policy_preprocessor.json"
preprocessor_config = json.loads(preprocessor_config_path.read_text())
preprocessor_config["steps"].insert(
2,
{
"registry_name": "relative_actions_processor",
"config": {
"enabled": True,
"exclude_joints": ["gripper"],
"action_names": action_names,
},
},
)
preprocessor_config_path.write_text(json.dumps(preprocessor_config, indent=4) + "\n")
postprocessor_config_path = save_dir / "policy_postprocessor.json"
postprocessor_config = json.loads(postprocessor_config_path.read_text())
postprocessor_config["steps"].insert(
-1,
{
"registry_name": "absolute_actions_processor",
"config": {"enabled": True},
},
)
postprocessor_config_path.write_text(json.dumps(postprocessor_config, indent=4) + "\n")
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
config,
pretrained_path=str(save_dir),
preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}},
)
relative_step = next(
step for step in loaded_preprocessor.steps if isinstance(step, RelativeActionsProcessorStep)
)
absolute_step = next(
step for step in loaded_postprocessor.steps if isinstance(step, AbsoluteActionsProcessorStep)
)
assert relative_step.enabled is True
assert relative_step.exclude_joints == ["gripper"]
assert relative_step.action_names == action_names
assert absolute_step.relative_step is relative_step
def test_groot_n1_7_pack_inputs_rejects_state_dim_above_core_max():
step = GrootN17PackInputsStep(
max_state_dim=2,