mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
propagate action_is_pad masking through VLA-JEPA policy pipeline
Pass the `action_is_pad` tensor from the batch through to the action head so padded timesteps are excluded from the flow-matching loss.
This commit is contained in:
@@ -123,6 +123,11 @@ class VLAJEPAModel(nn.Module):
|
|||||||
actions = [ex["action"] for ex in examples] if has_action else None
|
actions = [ex["action"] for ex in examples] if has_action else None
|
||||||
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||||
state = [ex["state"] for ex in examples] if has_state else None
|
state = [ex["state"] for ex in examples] if has_state else None
|
||||||
|
action_is_pad = (
|
||||||
|
[ex["action_is_pad"] for ex in examples]
|
||||||
|
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||||
batch_videos = np.stack(batch_videos)
|
batch_videos = np.stack(batch_videos)
|
||||||
@@ -243,7 +248,20 @@ class VLAJEPAModel(nn.Module):
|
|||||||
actions_rep = actions_target.repeat(num_repeated, 1, 1)
|
actions_rep = actions_target.repeat(num_repeated, 1, 1)
|
||||||
state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None
|
state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None
|
||||||
|
|
||||||
action_loss = self.action_model(embodied_rep, actions_rep, state_rep)
|
action_is_pad_rep = None
|
||||||
|
if action_is_pad is not None:
|
||||||
|
pad_tensor = torch.stack(
|
||||||
|
[
|
||||||
|
p.to(actions_target.device)
|
||||||
|
if isinstance(p, Tensor)
|
||||||
|
else torch.tensor(p, device=actions_target.device)
|
||||||
|
for p in action_is_pad
|
||||||
|
]
|
||||||
|
) # [B, T_full]
|
||||||
|
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||||
|
action_is_pad_rep = pad_tensor.repeat(num_repeated, 1) # [B*R, action_horizon]
|
||||||
|
|
||||||
|
action_loss = self.action_model(embodied_rep, actions_rep, state_rep, action_is_pad_rep)
|
||||||
|
|
||||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||||
|
|
||||||
@@ -414,11 +432,15 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# ---- Collect actions (training only) ----
|
# ---- Collect actions (training only) ----
|
||||||
actions_list = None
|
actions_list = None
|
||||||
|
action_is_pad_list = None
|
||||||
actions_tensor = batch.get(ACTION)
|
actions_tensor = batch.get(ACTION)
|
||||||
if actions_tensor is not None:
|
if actions_tensor is not None:
|
||||||
if actions_tensor.ndim == 2:
|
if actions_tensor.ndim == 2:
|
||||||
actions_tensor = actions_tensor.unsqueeze(1)
|
actions_tensor = actions_tensor.unsqueeze(1)
|
||||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||||
|
action_is_pad_tensor = batch.get("action_is_pad")
|
||||||
|
if action_is_pad_tensor is not None:
|
||||||
|
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||||
|
|
||||||
# ---- Collect state ----
|
# ---- Collect state ----
|
||||||
state_list = None
|
state_list = None
|
||||||
@@ -440,6 +462,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
}
|
}
|
||||||
if actions_list is not None:
|
if actions_list is not None:
|
||||||
example["action"] = actions_list[b]
|
example["action"] = actions_list[b]
|
||||||
|
if action_is_pad_list is not None:
|
||||||
|
example["action_is_pad"] = action_is_pad_list[b]
|
||||||
if state_list is not None:
|
if state_list is not None:
|
||||||
example["state"] = state_list[b]
|
example["state"] = state_list[b]
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
|
|||||||
Reference in New Issue
Block a user