fixing wm_loss not propagating

This commit is contained in:
Maxime Ellerbach
2026-05-13 12:38:28 +00:00
committed by Maximellerbach
parent 7368a0085a
commit ea535ad98d
@@ -448,8 +448,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
logs["action_loss"] = native_output["action_loss"].detach().item()
if "wm_loss" in native_output:
wm_loss = native_output["wm_loss"]
logs["wm_loss"] = wm_loss.detach().item()
total_loss = total_loss + native_output["wm_loss"]
logs["wm_loss"] = native_output["wm_loss"].detach().item()
logs["loss"] = (
total_loss.detach().item()