diff --git a/examples/training/smolvla2_hirobot.slurm b/examples/training/smolvla2_hirobot.slurm
index c1f950e8b..2a3eac1f8 100644
--- a/examples/training/smolvla2_hirobot.slurm
+++ b/examples/training/smolvla2_hirobot.slurm
@@ -63,6 +63,8 @@ accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
     --policy.compile_model=false \
     --policy.device=cuda \
     --policy.tokenizer_max_length=512 \
+    --policy.text_loss_weight=1.0 \
+    --policy.flow_loss_weight=10.0 \
     --steps="$STEPS" \
     --policy.scheduler_decay_steps="$STEPS" \
     --batch_size="$BATCH_SIZE" \
diff --git a/src/lerobot/policies/smolvla2/configuration_smolvla2.py b/src/lerobot/policies/smolvla2/configuration_smolvla2.py
index 99ce917e3..bc24139fd 100644
--- a/src/lerobot/policies/smolvla2/configuration_smolvla2.py
+++ b/src/lerobot/policies/smolvla2/configuration_smolvla2.py
@@ -69,12 +69,23 @@ class SmolVLA2Config(SmolVLAConfig):
     matches its training distribution."""

     # Loss weights --------------------------------------------------------
+    # Pi 0.5 paper §IV.D (Eq. 1) sets α = 10 between the text-CE term
+    # and the flow-MSE term: L = H(text) + α * ‖(ω - a) - f_θ‖². The
+    # rationale is that actions are the primary output, so the flow
+    # head should dominate the gradient signal; text is supervised as
+    # an auxiliary task, and its CE scale (~0.5-2.0 nats) tends to be
+    # larger than the flow-MSE scale (~0.1-1.0), so without
+    # up-weighting the action head gets starved. We mirror the
+    # paper's split here: text_loss_weight=1, flow_loss_weight=10.
     text_loss_weight: float = 1.0
     """Weight on the LM-head cross-entropy term. Set to ``0`` to disable
     text training entirely (reverts to flow-only / SmolVLA behaviour)."""

-    flow_loss_weight: float = 1.0
-    """Weight on the action-expert flow-matching term."""
+    flow_loss_weight: float = 10.0
+    """Weight on the action-expert flow-matching term. The default of
+    10.0 matches the Pi 0.5 paper's α (§IV.D). Set it lower if the
+    text head underfits relative to the action expert; set it higher
+    if the action expert degrades because the text loss dominates."""

     # Backbone training ---------------------------------------------------
     unfreeze_lm_head: bool = True
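
For reference, a minimal sketch of how the two weighted terms combine. This is
not the repo's actual modeling code: the helper name combined_loss and the
tensor names (text_logits, text_targets, flow_pred, noise, actions) are
hypothetical, and it assumes the flow target is the velocity omega - a, as in
the config comment above.

    import torch
    import torch.nn.functional as F

    def combined_loss(
        text_logits: torch.Tensor,   # (B, T, vocab) LM-head logits
        text_targets: torch.Tensor,  # (B, T) token ids; -100 marks ignored positions
        flow_pred: torch.Tensor,     # (B, H, action_dim) predicted velocity f_theta(x_tau)
        noise: torch.Tensor,         # (B, H, action_dim) sampled noise omega
        actions: torch.Tensor,       # (B, H, action_dim) ground-truth actions a
        text_loss_weight: float = 1.0,
        flow_loss_weight: float = 10.0,
    ) -> torch.Tensor:
        # Hypothetical helper, not the SmolVLA2 implementation:
        # L = w_text * H(text) + w_flow * ||(omega - a) - f_theta||^2
        text_ce = F.cross_entropy(
            text_logits.flatten(0, 1), text_targets.flatten(), ignore_index=-100
        )
        flow_mse = F.mse_loss(flow_pred, noise - actions)
        return text_loss_weight * text_ce + flow_loss_weight * flow_mse

With the defaults above and the loss scales quoted in the comment, a flow MSE
of ~0.3 contributes ~3.0 to the total versus ~1.0 from a typical text CE, so
the action head dominates the gradient as intended.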