fix(profiling): align libero smoke specs with pretrained policies

This commit is contained in:
Pepijn
2026-04-16 15:11:54 +02:00
parent 8ece10e484
commit 4137b5785d
2 changed files with 37 additions and 4 deletions
+32
View File
@@ -59,6 +59,38 @@ def test_profiling_specs_cover_expected_policies():
assert excluded not in specs
def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
module = _import_model_profiling_script()
spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"
specs = module.load_specs(spec_path)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0"]["train_args"]
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi0_fast"]["train_args"]
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
in specs["pi05"]["train_args"]
)
assert (
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", "
"\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}"
in specs["pi05"]["train_args"]
)
assert (
"--rename_map={\"observation.images.front\": \"observation.images.camera1\", "
"\"observation.images.wrist\": \"observation.images.camera2\"}"
in specs["smolvla"]["train_args"]
)
def test_build_train_command_includes_profiling_outputs(tmp_path):
module = _import_model_profiling_script()
spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"