diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 8ddb023da..8fd60c1fc 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -394,13 +394,21 @@ class SmolVLAPolicy(PreTrainedPolicy): loss_dict["losses_after_rm_padding"] = losses.clone().mean().item() if reduction == "none": - # Return per-sample losses (B,) by averaging over time and action dims - per_sample_loss = losses.mean(dim=(1, 2)) + # Return per-sample losses (B,) by averaging over valid (time, action) entries + if actions_is_pad is None: + per_sample_loss = losses.mean(dim=(1, 2)) + else: + num_valid = ((~actions_is_pad).sum(dim=1) * losses.shape[-1]).clamp_min(1) + per_sample_loss = losses.sum(dim=(1, 2)) / num_valid loss_dict["loss"] = per_sample_loss.mean().item() return per_sample_loss, loss_dict else: - # Default: return scalar mean loss - loss = losses.mean() + # Default: return scalar mean loss over valid (time, action) entries + if actions_is_pad is None: + loss = losses.mean() + else: + num_valid = ((~actions_is_pad).sum() * losses.shape[-1]).clamp_min(1) + loss = losses.sum() / num_valid loss_dict["loss"] = loss.item() return loss, loss_dict