mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(smolvla): correct loss normalization for padded actions (#3434)
Apply the same per-scalar-mean fix to SmolVLA that #3377 landed for ACT / Diffusion / MultiTaskDiT. The pre-patch form applies the `action_is_pad` mask to zero out padded timesteps, then calls `.mean()` (or `.mean(dim=(1, 2))`). Because `.mean()` divides by the total number of elements including the zeroed padding, the loss is diluted by the padding fraction. Fixed by normalizing only over valid (non-padded) scalar entries: num_valid = ((~actions_is_pad).sum(...) * losses.shape[-1]).clamp_min(1) loss = losses.sum(...) / num_valid `clamp_min(1)` preserves the all-padded-batch edge case (0/1 = 0). Both reduction paths are updated. Behavior when `action_is_pad` is missing is unchanged (`losses.mean()`). Empirical A/B on aloha_sim_transfer_cube_human (chunk_size=40, batch=2, 30 steps, fixed seed, GB200) shows `loss_A / loss_B = 0.9672 (±0.088)` — same direction and magnitude as PR #3377's `loss_A / loss_C ≈ 0.96` for ACT. Heavier-padding recipes will see a larger gap. Refs: #3353 (original report for ACT), #3377 (fix for the other three policies).
This commit is contained in:
@@ -394,13 +394,21 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
# Return per-sample losses (B,) by averaging over valid (time, action) entries
|
||||
if actions_is_pad is None:
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
else:
|
||||
num_valid = ((~actions_is_pad).sum(dim=1) * losses.shape[-1]).clamp_min(1)
|
||||
per_sample_loss = losses.sum(dim=(1, 2)) / num_valid
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
# Default: return scalar mean loss over valid (time, action) entries
|
||||
if actions_is_pad is None:
|
||||
loss = losses.mean()
|
||||
else:
|
||||
num_valid = ((~actions_is_pad).sum() * losses.shape[-1]).clamp_min(1)
|
||||
loss = losses.sum() / num_valid
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user