mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
fix(policies): crop losses based on the action dof (#3133)
Co-authored-by: Chenning Yu <rainorangelemon@gmail.com>
This commit is contained in:
@@ -377,6 +377,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
actions_is_pad = batch.get("action_is_pad")
|
actions_is_pad = batch.get("action_is_pad")
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||||
|
original_action_dim = self.config.action_feature.shape[0]
|
||||||
|
losses = losses[:, :, :original_action_dim]
|
||||||
loss_dict["losses_after_forward"] = losses.clone().mean().item()
|
loss_dict["losses_after_forward"] = losses.clone().mean().item()
|
||||||
|
|
||||||
if actions_is_pad is not None:
|
if actions_is_pad is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user