mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
Reconnect GR00T relative action processors
This commit is contained in:
@@ -33,12 +33,14 @@ else:
|
||||
ProcessorMixin = object
|
||||
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RelativeActionsProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
@@ -542,9 +544,26 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
|
||||
def _reconnect_groot_relative_absolute_steps(
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
) -> None:
|
||||
relative_step = next(
|
||||
(step for step in preprocessor.steps if isinstance(step, RelativeActionsProcessorStep)),
|
||||
None,
|
||||
)
|
||||
if relative_step is None:
|
||||
return
|
||||
|
||||
for step in postprocessor.steps:
|
||||
if isinstance(step, AbsoluteActionsProcessorStep) and step.relative_step is None:
|
||||
step.relative_step = relative_step
|
||||
|
||||
|
||||
def make_groot_pre_post_processors(
|
||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||
) -> tuple[
|
||||
|
||||
@@ -44,7 +44,11 @@ from lerobot.policies.groot.processor_groot import (
|
||||
_transform_n1_7_image_for_vlm,
|
||||
make_groot_pre_post_processors,
|
||||
)
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
@@ -701,6 +705,93 @@ def test_converted_raw_n1_7_processors_load_without_legacy_action_unpack_overrid
|
||||
assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps)
|
||||
|
||||
|
||||
def test_converted_raw_n1_7_absolute_action_processors_load_without_relative_steps(tmp_path):
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
_write_raw_n1_7_libero_checkpoint(model_path)
|
||||
config = _raw_n1_7_libero_config(model_path)
|
||||
preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path))
|
||||
save_dir = tmp_path / "absolute_pretrained_model"
|
||||
|
||||
config.save_pretrained(save_dir)
|
||||
preprocessor.save_pretrained(save_dir)
|
||||
postprocessor.save_pretrained(save_dir)
|
||||
|
||||
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
|
||||
config,
|
||||
pretrained_path=str(save_dir),
|
||||
preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}},
|
||||
)
|
||||
|
||||
assert any(isinstance(step, GrootN17PackInputsStep) for step in loaded_preprocessor.steps)
|
||||
assert any(isinstance(step, GrootN17ActionDecodeStep) for step in loaded_postprocessor.steps)
|
||||
assert not any(isinstance(step, RelativeActionsProcessorStep) for step in loaded_preprocessor.steps)
|
||||
assert not any(isinstance(step, AbsoluteActionsProcessorStep) for step in loaded_postprocessor.steps)
|
||||
|
||||
|
||||
def test_converted_raw_n1_7_relative_action_processors_reconnect_after_load(tmp_path):
|
||||
model_path = tmp_path / "libero_spatial"
|
||||
_write_raw_n1_7_libero_checkpoint(model_path)
|
||||
config = _raw_n1_7_libero_config(model_path)
|
||||
preprocessor, postprocessor = make_pre_post_processors(config, pretrained_path=str(model_path))
|
||||
save_dir = tmp_path / "relative_pretrained_model"
|
||||
action_names = [
|
||||
"shoulder_pan.pos",
|
||||
"shoulder_lift.pos",
|
||||
"elbow_flex.pos",
|
||||
"wrist_flex.pos",
|
||||
"wrist_roll.pos",
|
||||
"gripper.pos",
|
||||
]
|
||||
|
||||
config.save_pretrained(save_dir)
|
||||
preprocessor.save_pretrained(save_dir)
|
||||
postprocessor.save_pretrained(save_dir)
|
||||
|
||||
preprocessor_config_path = save_dir / "policy_preprocessor.json"
|
||||
preprocessor_config = json.loads(preprocessor_config_path.read_text())
|
||||
preprocessor_config["steps"].insert(
|
||||
2,
|
||||
{
|
||||
"registry_name": "relative_actions_processor",
|
||||
"config": {
|
||||
"enabled": True,
|
||||
"exclude_joints": ["gripper"],
|
||||
"action_names": action_names,
|
||||
},
|
||||
},
|
||||
)
|
||||
preprocessor_config_path.write_text(json.dumps(preprocessor_config, indent=4) + "\n")
|
||||
|
||||
postprocessor_config_path = save_dir / "policy_postprocessor.json"
|
||||
postprocessor_config = json.loads(postprocessor_config_path.read_text())
|
||||
postprocessor_config["steps"].insert(
|
||||
-1,
|
||||
{
|
||||
"registry_name": "absolute_actions_processor",
|
||||
"config": {"enabled": True},
|
||||
},
|
||||
)
|
||||
postprocessor_config_path.write_text(json.dumps(postprocessor_config, indent=4) + "\n")
|
||||
|
||||
loaded_preprocessor, loaded_postprocessor = make_pre_post_processors(
|
||||
config,
|
||||
pretrained_path=str(save_dir),
|
||||
preprocessor_overrides={"rename_observations_processor": {"rename_map": {}}},
|
||||
)
|
||||
|
||||
relative_step = next(
|
||||
step for step in loaded_preprocessor.steps if isinstance(step, RelativeActionsProcessorStep)
|
||||
)
|
||||
absolute_step = next(
|
||||
step for step in loaded_postprocessor.steps if isinstance(step, AbsoluteActionsProcessorStep)
|
||||
)
|
||||
|
||||
assert relative_step.enabled is True
|
||||
assert relative_step.exclude_joints == ["gripper"]
|
||||
assert relative_step.action_names == action_names
|
||||
assert absolute_step.relative_step is relative_step
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_inputs_rejects_state_dim_above_core_max():
|
||||
step = GrootN17PackInputsStep(
|
||||
max_state_dim=2,
|
||||
|
||||
Reference in New Issue
Block a user