From ea535ad98d693f9583dd32208c1583ae552452c2 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Wed, 13 May 2026 12:38:28 +0000 Subject: [PATCH] fixing wm_loss not propagating --- src/lerobot/policies/vla_jepa/modeling_vla_jepa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 21b55987c..448fc34f8 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -448,8 +448,8 @@ class VLAJEPAPolicy(PreTrainedPolicy): logs["action_loss"] = native_output["action_loss"].detach().item() if "wm_loss" in native_output: - wm_loss = native_output["wm_loss"] - logs["wm_loss"] = wm_loss.detach().item() + total_loss = total_loss + native_output["wm_loss"] + logs["wm_loss"] = native_output["wm_loss"].detach().item() logs["loss"] = ( total_loss.detach().item()