From 1add460678729bbd34d7a97f3fb455764f316f86 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 23 Apr 2026 15:23:54 +0200 Subject: [PATCH] fix(policy): loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT (#3442) * Fix loss normalization for padded actions in ACT, Diffusion, and MultiTaskDiT When action_is_pad masks out padded timesteps, the subsequent .mean() still divides by the total element count (including zeroed-out padding), underestimating the loss. With 60-70% padding this can cut the effective gradient signal by 2-3x. Replace mask-then-mean with mask-then-sum / valid-count for all three affected policies. TDMPC is not affected because it sums over time before averaging over batch. Fixes #3353 * linting Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Maxime Ellerbach * Update src/lerobot/policies/diffusion/modeling_diffusion.py Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Steven Palma * Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Steven Palma * Update src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> Signed-off-by: Steven Palma * apply ACT loss normalization suggestion from review Divide by num_valid (timesteps * action_dim) instead of just timesteps, matching the diffusion/multi_task_dit fix. Addresses review from @whats2000 (https://github.com/huggingface/lerobot/pull/3377#discussion_r3106845791). * fix(test): update safetensor act --------- Signed-off-by: Maxime Ellerbach Signed-off-by: Steven Palma Co-authored-by: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Co-authored-by: Maxime Ellerbach Co-authored-by: whats2000 <60466660+whats2000@users.noreply.github.com> --- src/lerobot/policies/act/modeling_act.py | 7 ++++--- src/lerobot/policies/diffusion/modeling_diffusion.py | 4 +++- .../policies/multi_task_dit/modeling_multi_task_dit.py | 10 ++++++---- .../actions.safetensors | 2 +- .../grad_stats.safetensors | 2 +- .../output_dict.safetensors | 2 +- .../param_stats.safetensors | 2 +- 7 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index 0120258ee..5651fbfb1 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy): actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) - l1_loss = ( - F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() + abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none") + valid_mask = ~batch["action_is_pad"].unsqueeze(-1) + num_valid = valid_mask.sum() * abs_err.shape[-1] + l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1) loss_dict = {"l1_loss": l1_loss.item()} if self.config.use_vae: diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 03203ffc8..9fbe1f703 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -380,7 +380,9 @@ class DiffusionModel(nn.Module): f"{self.config.do_mask_loss_for_padding=}." ) in_episode_bound = ~batch["action_is_pad"] - loss = loss * in_episode_bound.unsqueeze(-1) + mask = in_episode_bound.unsqueeze(-1) + num_valid = mask.sum() * loss.shape[-1] + return (loss * mask).sum() / num_valid.clamp_min(1) return loss.mean() diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 366b271c0..ceb4e211c 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module): loss = F.mse_loss(predicted, target, reduction="none") if self.do_mask_loss_for_padding and "action_is_pad" in batch: - valid_actions = ~batch["action_is_pad"] - loss = loss * valid_actions.unsqueeze(-1) + mask = ~batch["action_is_pad"].unsqueeze(-1) + num_valid = mask.sum() * loss.shape[-1] + return (loss * mask).sum() / num_valid.clamp_min(1) return loss.mean() @@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module): loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none") if self.do_mask_loss_for_padding and "action_is_pad" in batch: - valid_mask = ~batch["action_is_pad"] - loss = loss * valid_mask.unsqueeze(-1) + mask = ~batch["action_is_pad"].unsqueeze(-1) + num_valid = mask.sum() * loss.shape[-1] + return (loss * mask).sum() / num_valid.clamp_min(1) return loss.mean() diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors index dd7d4d0e7..5584b8643 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c +oid sha256:8a31653c11eccdd4d80fd3f6a351cd54c49b8a48db1f7e9faf38fddd7900a09f size 515400 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors index c58bb44bc..3baa80ba7 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6 +oid sha256:75bf051698b37dcd7517ec8025a896ab5a0551a6dde5f89d0a3d5d50966e83e6 size 31672 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors index 9b6ef7f5d..f3d1ff59a 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd +oid sha256:88e10930a10041d50f2cf369e6813ac14618d13dad1c21bdde1ac7798611c6ba size 68 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors index 5da67a1af..bdc26816f 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf +oid sha256:89833a5ccdb7d85c83f717ff8ec68b8e822005cb8803899acaae88c578e2e3ae size 31672