fix(policies): crop losses based on the action dof (#3133)

Co-authored-by: Chenning Yu <rainorangelemon@gmail.com>
This commit is contained in:
Steven Palma
2026-03-12 00:51:31 +01:00
committed by GitHub
parent c15b75e3da
commit efee611403
@@ -377,6 +377,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
original_action_dim = self.config.action_feature.shape[0]
losses = losses[:, :, :original_action_dim]
loss_dict["losses_after_forward"] = losses.clone().mean().item()
if actions_is_pad is not None: