fix(profiling): align libero smoke specs with pretrained policies

This commit is contained in:
Pepijn
2026-04-16 15:11:54 +02:00
parent 8ece10e484
commit 4137b5785d
2 changed files with 37 additions and 4 deletions
+5 -4
View File
@@ -57,7 +57,7 @@
"--policy.device=cuda", "--policy.device=cuda",
"--policy.n_action_steps=30", "--policy.n_action_steps=30",
"--batch_size=1", "--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
] ]
}, },
"pi0_fast": { "pi0_fast": {
@@ -69,7 +69,7 @@
"--policy.device=cuda", "--policy.device=cuda",
"--policy.n_action_steps=30", "--policy.n_action_steps=30",
"--batch_size=1", "--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
] ]
}, },
"pi05": { "pi05": {
@@ -81,7 +81,8 @@
"--policy.device=cuda", "--policy.device=cuda",
"--policy.n_action_steps=30", "--policy.n_action_steps=30",
"--batch_size=1", "--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", \"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}",
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
] ]
}, },
"smolvla": { "smolvla": {
@@ -96,7 +97,7 @@
"--policy.empty_cameras=1", "--policy.empty_cameras=1",
"--policy.device=cuda", "--policy.device=cuda",
"--batch_size=1", "--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}" "--rename_map={\"observation.images.front\": \"observation.images.camera1\", \"observation.images.wrist\": \"observation.images.camera2\"}"
] ]
}, },
"wall_x": { "wall_x": {
+32
View File
@@ -59,6 +59,38 @@ def test_profiling_specs_cover_expected_policies():
assert excluded not in specs assert excluded not in specs
def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
module = _import_model_profiling_script()
spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"
specs = module.load_specs(spec_path)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0"]["train_args"]
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0_fast"]["train_args"]
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi05"]["train_args"]
)
assert (
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", "
"\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}"
in specs["pi05"]["train_args"]
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.camera1\", "
"\"observation.images.wrist\": \"observation.images.camera2\"}"
in specs["smolvla"]["train_args"]
)
def test_build_train_command_includes_profiling_outputs(tmp_path): def test_build_train_command_includes_profiling_outputs(tmp_path):
module = _import_model_profiling_script() module = _import_model_profiling_script()
spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json" spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"