From a4544ffea7761f995fa16f52fc5605d054d3283b Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 16 Apr 2026 15:35:25 +0200 Subject: [PATCH] fix(profiling): use bf16 dtype and gradient checkpointing for pi0/pi05 Enable --policy.dtype=bfloat16 and --policy.gradient_checkpointing=true for pi0, pi0_fast, and pi05 profiling specs. Combined with use_amp=true, this brings the 4B-param VLA models well within the 22GB GPU budget. Made-with: Cursor --- profiling/model_profiling_specs.json | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/profiling/model_profiling_specs.json b/profiling/model_profiling_specs.json index 8774c993e..1d782fc3d 100644 --- a/profiling/model_profiling_specs.json +++ b/profiling/model_profiling_specs.json @@ -58,8 +58,10 @@ "--dataset.episodes=[0]", "--policy.path=lerobot/pi0_base", "--policy.device=cuda", + "--policy.dtype=bfloat16", "--policy.n_action_steps=30", "--policy.use_amp=true", + "--policy.gradient_checkpointing=true", "--batch_size=1", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] @@ -71,8 +73,10 @@ "--dataset.episodes=[0]", "--policy.path=lerobot/pi0fast-base", "--policy.device=cuda", + "--policy.dtype=bfloat16", "--policy.n_action_steps=30", "--policy.use_amp=true", + "--policy.gradient_checkpointing=true", "--batch_size=1", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] @@ -84,8 +88,10 @@ "--dataset.episodes=[0]", "--policy.path=lerobot/pi05_base", "--policy.device=cuda", + "--policy.dtype=bfloat16", "--policy.n_action_steps=30", "--policy.use_amp=true", + "--policy.gradient_checkpointing=true", "--batch_size=1", "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", \"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}"