diff --git a/profiling/model_profiling_specs.json b/profiling/model_profiling_specs.json index 8774c993e..1d782fc3d 100644 --- a/profiling/model_profiling_specs.json +++ b/profiling/model_profiling_specs.json @@ -58,8 +58,10 @@ "--dataset.episodes=[0]", "--policy.path=lerobot/pi0_base", "--policy.device=cuda", + "--policy.dtype=bfloat16", "--policy.n_action_steps=30", "--policy.use_amp=true", + "--policy.gradient_checkpointing=true", "--batch_size=1", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] @@ -71,8 +73,10 @@ "--dataset.episodes=[0]", "--policy.path=lerobot/pi0fast-base", "--policy.device=cuda", + "--policy.dtype=bfloat16", "--policy.n_action_steps=30", "--policy.use_amp=true", + "--policy.gradient_checkpointing=true", "--batch_size=1", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] @@ -84,8 +88,10 @@ "--dataset.episodes=[0]", "--policy.path=lerobot/pi05_base", "--policy.device=cuda", + "--policy.dtype=bfloat16", "--policy.n_action_steps=30", "--policy.use_amp=true", + "--policy.gradient_checkpointing=true", "--batch_size=1", "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", \"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"