mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix(policy): loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT (#3442)
* Fix loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT When action_is_pad masks out padded timesteps, the subsequent .mean() still divides by the total element count (including zeroed-out padding), underestimating the loss. With 60-70% padding this can cut the effective gradient signal by 2-3x. Replace mask-then-mean with mask-then-sum / valid-count for all three affected policies. TDMPC is not affected because it sums over time before averaging over batch. Fixes #3353 * linting Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> * Update src/lerobot/policies/diffusion/modeling_diffusion.py Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> * Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> * Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> * apply ACT loss normalization suggestion from review Divide by num_valid (timesteps * action_dim) instead of just timesteps, matching the diffusion/multi_task_dit fix. Addresses review from @whats2000 (https://github.com/huggingface/lerobot/pull/3377#discussion_r3106845791). * fix(test): update safetensor act --------- Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Co-authored-by: Maxime Ellerbach <maxime@ellerbach.net> Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
This commit is contained in:
@@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
|
||||
valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||
num_valid = valid_mask.sum() * abs_err.shape[-1]
|
||||
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
|
||||
@@ -380,7 +380,9 @@ class DiffusionModel(nn.Module):
|
||||
f"{self.config.do_mask_loss_for_padding=}."
|
||||
)
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
mask = in_episode_bound.unsqueeze(-1)
|
||||
num_valid = mask.sum() * loss.shape[-1]
|
||||
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
@@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module):
|
||||
loss = F.mse_loss(predicted, target, reduction="none")
|
||||
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
valid_actions = ~batch["action_is_pad"]
|
||||
loss = loss * valid_actions.unsqueeze(-1)
|
||||
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||
num_valid = mask.sum() * loss.shape[-1]
|
||||
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
@@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module):
|
||||
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
||||
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
valid_mask = ~batch["action_is_pad"]
|
||||
loss = loss * valid_mask.unsqueeze(-1)
|
||||
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||
num_valid = mask.sum() * loss.shape[-1]
|
||||
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user