mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
fix(profiling): align libero smoke specs with pretrained policies
This commit is contained in:
@@ -57,7 +57,7 @@
|
|||||||
"--policy.device=cuda",
|
"--policy.device=cuda",
|
||||||
"--policy.n_action_steps=30",
|
"--policy.n_action_steps=30",
|
||||||
"--batch_size=1",
|
"--batch_size=1",
|
||||||
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
|
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"pi0_fast": {
|
"pi0_fast": {
|
||||||
@@ -69,7 +69,7 @@
|
|||||||
"--policy.device=cuda",
|
"--policy.device=cuda",
|
||||||
"--policy.n_action_steps=30",
|
"--policy.n_action_steps=30",
|
||||||
"--batch_size=1",
|
"--batch_size=1",
|
||||||
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
|
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"pi05": {
|
"pi05": {
|
||||||
@@ -81,7 +81,8 @@
|
|||||||
"--policy.device=cuda",
|
"--policy.device=cuda",
|
||||||
"--policy.n_action_steps=30",
|
"--policy.n_action_steps=30",
|
||||||
"--batch_size=1",
|
"--batch_size=1",
|
||||||
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
|
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", \"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}",
|
||||||
|
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"smolvla": {
|
"smolvla": {
|
||||||
@@ -96,7 +97,7 @@
|
|||||||
"--policy.empty_cameras=1",
|
"--policy.empty_cameras=1",
|
||||||
"--policy.device=cuda",
|
"--policy.device=cuda",
|
||||||
"--batch_size=1",
|
"--batch_size=1",
|
||||||
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
|
"--rename_map={\"observation.images.front\": \"observation.images.camera1\", \"observation.images.wrist\": \"observation.images.camera2\"}"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"wall_x": {
|
"wall_x": {
|
||||||
|
|||||||
@@ -59,6 +59,38 @@ def test_profiling_specs_cover_expected_policies():
|
|||||||
assert excluded not in specs
|
assert excluded not in specs
|
||||||
|
|
||||||
|
|
||||||
|
def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
|
||||||
|
module = _import_model_profiling_script()
|
||||||
|
spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"
|
||||||
|
specs = module.load_specs(spec_path)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
|
||||||
|
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||||
|
in specs["pi0"]["train_args"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
|
||||||
|
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||||
|
in specs["pi0_fast"]["train_args"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", "
|
||||||
|
"\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"
|
||||||
|
in specs["pi05"]["train_args"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", "
|
||||||
|
"\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}"
|
||||||
|
in specs["pi05"]["train_args"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"--rename_map={\"observation.images.front\": \"observation.images.camera1\", "
|
||||||
|
"\"observation.images.wrist\": \"observation.images.camera2\"}"
|
||||||
|
in specs["smolvla"]["train_args"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_build_train_command_includes_profiling_outputs(tmp_path):
|
def test_build_train_command_includes_profiling_outputs(tmp_path):
|
||||||
module = _import_model_profiling_script()
|
module = _import_model_profiling_script()
|
||||||
spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"
|
spec_path = Path(__file__).resolve().parents[2] / "profiling" / "model_profiling_specs.json"
|
||||||
|
|||||||
Reference in New Issue
Block a user