From 2236cdb302cad685798c5f09ea3c713b824a104a Mon Sep 17 00:00:00 2001 From: whats2000 <60466660+whats2000@users.noreply.github.com> Date: Thu, 23 Apr 2026 16:34:11 +0800 Subject: [PATCH] fix(smolvla): correct loss normalization for padded actions (#3434) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply the same per-scalar-mean fix to SmolVLA that #3377 landed for ACT / Diffusion / MultiTaskDiT. The pre-patch form applies the `action_is_pad` mask to zero out padded timesteps, then calls `.mean()` (or `.mean(dim=(1, 2))`). Because `.mean()` divides by the total number of elements including the zeroed padding, the loss is diluted by the padding fraction. Fixed by normalizing only over valid (non-padded) scalar entries: num_valid = ((~actions_is_pad).sum(...) * losses.shape[-1]).clamp_min(1) loss = losses.sum(...) / num_valid `clamp_min(1)` preserves the all-padded-batch edge case (0/1 = 0). Both reduction paths are updated. Behavior when `action_is_pad` is missing is unchanged (`losses.mean()`). Empirical A/B on aloha_sim_transfer_cube_human (chunk_size=40, batch=2, 30 steps, fixed seed, GB200) shows `loss_A / loss_B = 0.9672 (±0.088)` — same direction and magnitude as PR #3377's `loss_A / loss_C ≈ 0.96` for ACT. Heavier-padding recipes will see a larger gap. Refs: #3353 (original report for ACT), #3377 (fix for the other three policies). --- src/lerobot/policies/smolvla/modeling_smolvla.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 8ddb023da..8fd60c1fc 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -394,13 +394,21 @@ class SmolVLAPolicy(PreTrainedPolicy): loss_dict["losses_after_rm_padding"] = losses.clone().mean().item() if reduction == "none": - # Return per-sample losses (B,) by averaging over time and action dims - per_sample_loss = losses.mean(dim=(1, 2)) + # Return per-sample losses (B,) by averaging over valid (time, action) entries + if actions_is_pad is None: + per_sample_loss = losses.mean(dim=(1, 2)) + else: + num_valid = ((~actions_is_pad).sum(dim=1) * losses.shape[-1]).clamp_min(1) + per_sample_loss = losses.sum(dim=(1, 2)) / num_valid loss_dict["loss"] = per_sample_loss.mean().item() return per_sample_loss, loss_dict else: - # Default: return scalar mean loss - loss = losses.mean() + # Default: return scalar mean loss over valid (time, action) entries + if actions_is_pad is None: + loss = losses.mean() + else: + num_valid = ((~actions_is_pad).sum() * losses.shape[-1]).clamp_min(1) + loss = losses.sum() / num_valid loss_dict["loss"] = loss.item() return loss, loss_dict