diff --git a/profiling/model_profiling_specs.json b/profiling/model_profiling_specs.json index 1d782fc3d..f30337663 100644 --- a/profiling/model_profiling_specs.json +++ b/profiling/model_profiling_specs.json @@ -63,6 +63,16 @@ "--policy.use_amp=true", "--policy.gradient_checkpointing=true", "--batch_size=1", + "--use_policy_training_preset=false", + "--optimizer.type=sgd", + "--optimizer.lr=1e-5", + "--optimizer.weight_decay=0", + "--optimizer.grad_clip_norm=1.0", + "--scheduler.type=cosine_decay_with_warmup", + "--scheduler.peak_lr=1e-5", + "--scheduler.decay_lr=1e-6", + "--scheduler.num_warmup_steps=0", + "--scheduler.num_decay_steps=12", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] }, @@ -78,6 +88,16 @@ "--policy.use_amp=true", "--policy.gradient_checkpointing=true", "--batch_size=1", + "--use_policy_training_preset=false", + "--optimizer.type=sgd", + "--optimizer.lr=1e-5", + "--optimizer.weight_decay=0", + "--optimizer.grad_clip_norm=1.0", + "--scheduler.type=cosine_decay_with_warmup", + "--scheduler.peak_lr=1e-5", + "--scheduler.decay_lr=1e-6", + "--scheduler.num_warmup_steps=0", + "--scheduler.num_decay_steps=12", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] }, @@ -93,6 +113,16 @@ "--policy.use_amp=true", "--policy.gradient_checkpointing=true", "--batch_size=1", + "--use_policy_training_preset=false", + "--optimizer.type=sgd", + "--optimizer.lr=1e-5", + "--optimizer.weight_decay=0", + "--optimizer.grad_clip_norm=1.0", + "--scheduler.type=cosine_decay_with_warmup", + "--scheduler.peak_lr=1e-5", + "--scheduler.decay_lr=1e-6", + "--scheduler.num_warmup_steps=0", + "--scheduler.num_decay_steps=12", "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", \"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}", "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", \"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" ] diff --git a/src/lerobot/utils/profiling_utils.py b/src/lerobot/utils/profiling_utils.py index 5cfc66f3b..85931051d 100644 --- a/src/lerobot/utils/profiling_utils.py +++ b/src/lerobot/utils/profiling_utils.py @@ -343,6 +343,8 @@ class TrainingProfiler: output_dir=self._output_dir, device_type=self._device.type, ) + if self._device.type == "cuda": + torch.cuda.empty_cache() def __enter__(self) -> TrainingProfiler: if self._device.type == "cuda":