diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index b1181765e..a13767794 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -1117,7 +1117,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` "flow_mse_loss": flow_loss.mean(), "action_ce_loss": fast_loss.mean(), "subtask_ce_loss": subtask_loss, - "loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(), # TODO: jadechoghari: check weights + "loss": flow_loss.mean() + subtask_loss.mean() + fast_loss.mean(), # TODO: jadechoghari: check weights } @torch.no_grad() # see openpi `sample_actions` (slightly adapted)