Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-11 14:49:43 +00:00)
fix(policy): loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT (#3442)
* Fix loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT

  When action_is_pad masks out padded timesteps, the subsequent .mean() still divides by the
  total element count (including the zeroed-out padding), underestimating the loss. With 60-70%
  padding this can cut the effective gradient signal by 2-3x. Replace mask-then-mean with
  mask-then-sum / valid-count for all three affected policies. TDMPC is not affected because it
  sums over time before averaging over batch.

  Fixes #3353

* linting

  Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
  Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>

* Update src/lerobot/policies/diffusion/modeling_diffusion.py

  Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
  Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

  Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
  Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py

  Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
  Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

* apply ACT loss normalization suggestion from review

  Divide by num_valid (timesteps * action_dim) instead of just timesteps, matching the
  diffusion/multi_task_dit fix. Addresses review from @whats2000
  (https://github.com/huggingface/lerobot/pull/3377#discussion_r3106845791).

* fix(test): update safetensor act

---------

Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com>
Co-authored-by: Maxime Ellerbach <maxime@ellerbach.net>
Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com>
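For reference, a minimal standalone sketch of the two normalizations being contrasted here. The tensor names, shapes, and toy values below are illustrative only; just the arithmetic mirrors the patch:

import torch

# Toy setup: batch of 2 sequences, 5 timesteps, 3-dim actions, with the last
# 3 timesteps of every sequence padded (60% padding) and a per-element loss of 1.
per_element_loss = torch.ones(2, 5, 3)
action_is_pad = torch.tensor([[False, False, True, True, True]] * 2)

valid_mask = ~action_is_pad.unsqueeze(-1)  # (B, T, 1), True on real timesteps

# Old behaviour: zero out the padding, then mean over *all* elements.
# The denominator still counts padded entries, so the loss is underestimated.
old_loss = (per_element_loss * valid_mask).mean()  # 0.4

# New behaviour: sum over valid elements, divide by the valid element count.
num_valid = valid_mask.sum() * per_element_loss.shape[-1]  # valid timesteps * action_dim
new_loss = (per_element_loss * valid_mask).sum() / num_valid.clamp_min(1)  # 1.0

print(old_loss.item(), new_loss.item())  # 0.4 vs 1.0: a 2.5x gap at 60% padding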
@@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy):
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

-        l1_loss = (
-            F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
-        ).mean()
+        abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
+        valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
+        num_valid = valid_mask.sum() * abs_err.shape[-1]
+        l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)

         loss_dict = {"l1_loss": l1_loss.item()}
         if self.config.use_vae:
@@ -380,7 +380,9 @@ class DiffusionModel(nn.Module):
                     f"{self.config.do_mask_loss_for_padding=}."
                 )
             in_episode_bound = ~batch["action_is_pad"]
-            loss = loss * in_episode_bound.unsqueeze(-1)
+            mask = in_episode_bound.unsqueeze(-1)
+            num_valid = mask.sum() * loss.shape[-1]
+            return (loss * mask).sum() / num_valid.clamp_min(1)

         return loss.mean()
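When nothing is padded, the sum / valid-count form reduces to the plain mean, so the unmasked code path (and training without padding) behaves as before. A quick equivalence check, with illustrative shapes:

import torch

loss = torch.rand(2, 4, 3)
action_is_pad = torch.zeros(2, 4, dtype=torch.bool)  # no padding anywhere

mask = ~action_is_pad.unsqueeze(-1)
num_valid = mask.sum() * loss.shape[-1]  # 2 * 4 * 3 = 24 = loss.numel()

masked_mean = (loss * mask).sum() / num_valid.clamp_min(1)
assert torch.allclose(masked_mean, loss.mean())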
@@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module):
         loss = F.mse_loss(predicted, target, reduction="none")

         if self.do_mask_loss_for_padding and "action_is_pad" in batch:
-            valid_actions = ~batch["action_is_pad"]
-            loss = loss * valid_actions.unsqueeze(-1)
+            mask = ~batch["action_is_pad"].unsqueeze(-1)
+            num_valid = mask.sum() * loss.shape[-1]
+            return (loss * mask).sum() / num_valid.clamp_min(1)

         return loss.mean()
@@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module):
         loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")

         if self.do_mask_loss_for_padding and "action_is_pad" in batch:
-            valid_mask = ~batch["action_is_pad"]
-            loss = loss * valid_mask.unsqueeze(-1)
+            mask = ~batch["action_is_pad"].unsqueeze(-1)
+            num_valid = mask.sum() * loss.shape[-1]
+            return (loss * mask).sum() / num_valid.clamp_min(1)

         return loss.mean()
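Every patched branch above guards the division with clamp_min(1): if a batch somehow contained only padded timesteps, num_valid would be zero and the division would produce NaN, whereas clamping the denominator to 1 makes the loss exactly 0. A small sketch of that edge case (shapes and values are illustrative):

import torch

loss = torch.rand(2, 4, 3)
action_is_pad = torch.ones(2, 4, dtype=torch.bool)  # every timestep padded

mask = ~action_is_pad.unsqueeze(-1)  # no valid entries
num_valid = mask.sum() * loss.shape[-1]  # 0

unguarded = (loss * mask).sum() / num_valid  # nan (0 / 0)
guarded = (loss * mask).sum() / num_valid.clamp_min(1)  # 0.0
print(unguarded.item(), guarded.item())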
Git LFS pointer updates for the safetensors test artifacts:

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c
+oid sha256:8a31653c11eccdd4d80fd3f6a351cd54c49b8a48db1f7e9faf38fddd7900a09f
 size 515400
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
+oid sha256:75bf051698b37dcd7517ec8025a896ab5a0551a6dde5f89d0a3d5d50966e83e6
 size 31672
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
+oid sha256:88e10930a10041d50f2cf369e6813ac14618d13dad1c21bdde1ac7798611c6ba
 size 68
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf
+oid sha256:89833a5ccdb7d85c83f717ff8ec68b8e822005cb8803899acaae88c578e2e3ae
 size 31672