diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index 0120258ee..5651fbfb1 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy): actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) - l1_loss = ( - F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) - ).mean() + abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none") + valid_mask = ~batch["action_is_pad"].unsqueeze(-1) + num_valid = valid_mask.sum() * abs_err.shape[-1] + l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1) loss_dict = {"l1_loss": l1_loss.item()} if self.config.use_vae: diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 03203ffc8..9fbe1f703 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -380,7 +380,9 @@ class DiffusionModel(nn.Module): f"{self.config.do_mask_loss_for_padding=}." ) in_episode_bound = ~batch["action_is_pad"] - loss = loss * in_episode_bound.unsqueeze(-1) + mask = in_episode_bound.unsqueeze(-1) + num_valid = mask.sum() * loss.shape[-1] + return (loss * mask).sum() / num_valid.clamp_min(1) return loss.mean() diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 366b271c0..ceb4e211c 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module): loss = F.mse_loss(predicted, target, reduction="none") if self.do_mask_loss_for_padding and "action_is_pad" in batch: - valid_actions = ~batch["action_is_pad"] - loss = loss * valid_actions.unsqueeze(-1) + mask = ~batch["action_is_pad"].unsqueeze(-1) + num_valid = mask.sum() * loss.shape[-1] + return (loss * mask).sum() / num_valid.clamp_min(1) return loss.mean() @@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module): loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none") if self.do_mask_loss_for_padding and "action_is_pad" in batch: - valid_mask = ~batch["action_is_pad"] - loss = loss * valid_mask.unsqueeze(-1) + mask = ~batch["action_is_pad"].unsqueeze(-1) + num_valid = mask.sum() * loss.shape[-1] + return (loss * mask).sum() / num_valid.clamp_min(1) return loss.mean() diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors index dd7d4d0e7..5584b8643 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c +oid sha256:8a31653c11eccdd4d80fd3f6a351cd54c49b8a48db1f7e9faf38fddd7900a09f size 515400 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors index c58bb44bc..3baa80ba7 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6 +oid sha256:75bf051698b37dcd7517ec8025a896ab5a0551a6dde5f89d0a3d5d50966e83e6 size 31672 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors index 9b6ef7f5d..f3d1ff59a 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd +oid sha256:88e10930a10041d50f2cf369e6813ac14618d13dad1c21bdde1ac7798611c6ba size 68 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors index 5da67a1af..bdc26816f 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf +oid sha256:89833a5ccdb7d85c83f717ff8ec68b8e822005cb8803899acaae88c578e2e3ae size 31672