diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 21b55987c..448fc34f8 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -448,8 +448,8 @@ class VLAJEPAPolicy(PreTrainedPolicy): logs["action_loss"] = native_output["action_loss"].detach().item() if "wm_loss" in native_output: - wm_loss = native_output["wm_loss"] - logs["wm_loss"] = wm_loss.detach().item() + total_loss = total_loss + native_output["wm_loss"] + logs["wm_loss"] = native_output["wm_loss"].detach().item() logs["loss"] = ( total_loss.detach().item()