From 4137b5785d3336b250d5dd419192b769aa96b9bb Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 16 Apr 2026 15:11:54 +0200 Subject: [PATCH] fix(profiling): align libero smoke specs with pretrained policies --- profiling/model_profiling_specs.json | 9 ++++---- tests/scripts/test_model_profiling.py | 32 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/profiling/model_profiling_specs.json b/profiling/model_profiling_specs.json index f283c18cd..3f526882d 100644 --- a/profiling/model_profiling_specs.json +++ b/profiling/model_profiling_specs.json @@ -57,7 +57,7 @@ "--policy.device=cuda", "--policy.n_action_steps=30", "--batch_size=1", - "--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" + "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] }, "pi0_fast": { @@ -69,7 +69,7 @@ "--policy.device=cuda", "--policy.n_action_steps=30", "--batch_size=1", - "--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" + "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] }, "pi05": { @@ -81,7 +81,8 @@ "--policy.device=cuda", "--policy.n_action_steps=30", "--batch_size=1", - "--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" + "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", \"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}", + "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] }, "smolvla": { @@ -96,7 +97,7 @@ "--policy.empty_cameras=1", "--policy.device=cuda", "--batch_size=1", - "--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" + "--rename_map={\"observation.images.front\": \"observation.images.camera1\", \"observation.images.wrist\": \"observation.images.camera2\"}" ] }, "wall_x": { diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py index 8463cd215..535d71150 100644 --- a/tests/scripts/test_model_profiling.py +++ b/tests/scripts/test_model_profiling.py @@ -59,6 +59,38 @@ def test_profiling_specs_cover_expected_policies(): assert excluded not in specs +def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization(): + module = _import_model_profiling_script() + spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json" + specs = module.load_specs(spec_path) + + assert ( + "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", " + "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" + in specs["pi0"]["train_args"] + ) + assert ( + "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", " + "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" + in specs["pi0_fast"]["train_args"] + ) + assert ( + "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", " + "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" + in specs["pi05"]["train_args"] + ) + assert ( + "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", " + "\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}" + in specs["pi05"]["train_args"] + ) + assert ( + "--rename_map={\"observation.images.front\": \"observation.images.camera1\", " + "\"observation.images.wrist\": \"observation.images.camera2\"}" + in specs["smolvla"]["train_args"] + ) + + def test_build_train_command_includes_profiling_outputs(tmp_path): module = _import_model_profiling_script() spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"