fix(pi052): supervise only FAST action-code tokens

Mask the FAST auxiliary loss to discrete action-code tokens so wrapper formatting tokens do not affect action co-training.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 17:38:34 +00:00
parent 474c5478d9
commit 0e2dc1b76f
4 changed files with 134 additions and 30 deletions
+27 -14
View File
@@ -124,19 +124,19 @@ def _mark_target_span_causal(
def _fast_ce( def _fast_ce(
fast_logits: Tensor, fast_logits: Tensor,
action_tokens: Tensor, action_tokens: Tensor,
action_mask: Tensor, action_code_mask: Tensor,
predict_actions_t: Tensor | None, predict_actions_t: Tensor | None,
) -> Tensor: ) -> Tensor:
"""FAST-CE with both token-pad masking and per-sample action gating. """FAST action-code CE with token-span masking and per-sample action gating.
``action_mask`` is the FAST tokenizer's padding mask; samples whose ``action_code_mask`` is true only on the discrete action-code tokens,
recipe sets ``predict_actions=False`` (e.g. plan_generation, excluding the BOS / "Action: " / delimiter wrapper. Samples whose
ask_vqa_*) get *all* their FAST positions masked out via the recipe sets ``predict_actions=False`` get all code positions masked
per-sample gate. out via the per-sample gate.
""" """
shift_logits = fast_logits[:, :-1, :].contiguous() shift_logits = fast_logits[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long() shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_mask[:, 1:].contiguous().bool() shift_valid = action_code_mask[:, 1:].contiguous().bool()
if predict_actions_t is not None: if predict_actions_t is not None:
sample_mask = predict_actions_t[:, None].expand_as(shift_valid) sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
shift_valid = shift_valid & sample_mask shift_valid = shift_valid & sample_mask
@@ -429,13 +429,18 @@ class PI052Policy(PI05Policy):
and self.config.fast_action_loss_weight > 0 and self.config.fast_action_loss_weight > 0
and (predict_actions_t is None or bool(predict_actions_t.any().item())) and (predict_actions_t is None or bool(predict_actions_t.any().item()))
) )
action_tokens = action_mask = None action_tokens = action_mask = action_code_mask = None
if run_fast: if run_fast:
from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415 from lerobot.utils.constants import ( # noqa: PLC0415
ACTION_CODE_TOKEN_MASK,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
)
action_tokens = batch.get(ACTION_TOKENS) action_tokens = batch.get(ACTION_TOKENS)
action_mask = batch.get(ACTION_TOKEN_MASK) action_mask = batch.get(ACTION_TOKEN_MASK)
if action_tokens is None or action_mask is None: action_code_mask = batch.get(ACTION_CODE_TOKEN_MASK)
if action_tokens is None or action_mask is None or action_code_mask is None:
run_fast = False run_fast = False
# ------------------------------------------------------------ # ------------------------------------------------------------
@@ -457,6 +462,7 @@ class PI052Policy(PI05Policy):
text_labels=text_labels if run_text else None, text_labels=text_labels if run_text else None,
action_tokens=action_tokens if run_fast else None, action_tokens=action_tokens if run_fast else None,
action_mask=action_mask if run_fast else None, action_mask=action_mask if run_fast else None,
action_code_mask=action_code_mask if run_fast else None,
predict_actions_t=predict_actions_t, predict_actions_t=predict_actions_t,
) )
loss_dict["flow_loss"] = float(flow_loss.detach().item()) loss_dict["flow_loss"] = float(flow_loss.detach().item())
@@ -473,6 +479,7 @@ class PI052Policy(PI05Policy):
text_labels=text_labels if run_text else None, text_labels=text_labels if run_text else None,
action_tokens=action_tokens if run_fast else None, action_tokens=action_tokens if run_fast else None,
action_mask=action_mask if run_fast else None, action_mask=action_mask if run_fast else None,
action_code_mask=action_code_mask if run_fast else None,
predict_actions_t=predict_actions_t, predict_actions_t=predict_actions_t,
) )
if text_loss is not None: if text_loss is not None:
@@ -508,6 +515,7 @@ class PI052Policy(PI05Policy):
text_labels: Tensor | None, text_labels: Tensor | None,
action_tokens: Tensor | None, action_tokens: Tensor | None,
action_mask: Tensor | None, action_mask: Tensor | None,
action_code_mask: Tensor | None,
predict_actions_t: Tensor | None = None, predict_actions_t: Tensor | None = None,
) -> tuple[Tensor, Tensor | None, Tensor | None]: ) -> tuple[Tensor, Tensor | None, Tensor | None]:
"""Full fusion: flow + text + FAST in ONE backbone forward. """Full fusion: flow + text + FAST in ONE backbone forward.
@@ -647,10 +655,10 @@ class PI052Policy(PI05Policy):
text_loss = _shifted_ce(text_logits, text_labels) text_loss = _shifted_ce(text_logits, text_labels)
fast_loss: Tensor | None = None fast_loss: Tensor | None = None
if fast_len > 0 and prefix_out is not None: if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
fast_hidden = prefix_out[:, -fast_len:, :] fast_hidden = prefix_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t) fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
return flow_loss, text_loss, fast_loss return flow_loss, text_loss, fast_loss
@@ -660,6 +668,7 @@ class PI052Policy(PI05Policy):
text_labels: Tensor | None, text_labels: Tensor | None,
action_tokens: Tensor | None, action_tokens: Tensor | None,
action_mask: Tensor | None, action_mask: Tensor | None,
action_code_mask: Tensor | None,
predict_actions_t: Tensor | None = None, predict_actions_t: Tensor | None = None,
) -> tuple[Tensor | None, Tensor | None]: ) -> tuple[Tensor | None, Tensor | None]:
"""Single prefix forward → text CE + FAST CE. """Single prefix forward → text CE + FAST CE.
@@ -745,10 +754,14 @@ class PI052Policy(PI05Policy):
text_loss = _shifted_ce(text_logits, text_labels) text_loss = _shifted_ce(text_logits, text_labels)
fast_loss: Tensor | None = None fast_loss: Tensor | None = None
if action_tokens is not None and action_mask is not None and fast_len > 0: if (
action_tokens is not None
and action_code_mask is not None
and fast_len > 0
):
fast_hidden = vlm_out[:, -fast_len:, :] fast_hidden = vlm_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t) fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
return text_loss, fast_loss return text_loss, fast_loss
+31 -16
View File
@@ -32,6 +32,7 @@ import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, RobotObservation, TransitionKey from lerobot.types import EnvTransition, RobotObservation, TransitionKey
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION_CODE_TOKEN_MASK,
ACTION_TOKEN_MASK, ACTION_TOKEN_MASK,
ACTION_TOKENS, ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
@@ -412,14 +413,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# During inference, no action is available, skip tokenization # During inference, no action is available, skip tokenization
return new_transition return new_transition
# Tokenize and get both tokens and mask # Tokenize and get masks for the full formatted sequence and the discrete action codes.
tokens, mask = self._tokenize_action(action) tokens, mask, code_mask = self._tokenize_action(action)
# Store mask in complementary data # Store mask in complementary data
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if complementary_data is None: if complementary_data is None:
complementary_data = {} complementary_data = {}
complementary_data[ACTION_TOKEN_MASK] = mask complementary_data[ACTION_TOKEN_MASK] = mask
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
complementary_data[ACTION_TOKENS] = tokens complementary_data[ACTION_TOKENS] = tokens
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition return new_transition
@@ -430,7 +432,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
""" """
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Tokenizes the action tensor and creates a mask. Tokenizes the action tensor and creates a mask.
@@ -459,6 +461,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# The fast tokenizer expects action data and returns token IDs # The fast tokenizer expects action data and returns token IDs
tokens_list = [] tokens_list = []
masks_list = [] masks_list = []
code_masks_list = []
for i in range(batch_size): for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy) # Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
@@ -476,19 +479,26 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
if tokens.dim() > 1: if tokens.dim() > 1:
tokens = tokens.flatten() tokens = tokens.flatten()
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
bos_id = self._paligemma_tokenizer.bos_token_id bos_id = self._paligemma_tokenizer.bos_token_id
# add bos prompt_tokens = torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
)
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
code_start = 1 + len(prompt_tokens)
code_end = code_start + len(action_code_tokens)
tokens = torch.cat( tokens = torch.cat(
[ [
torch.tensor([bos_id], device=action.device), torch.tensor([bos_id], device=action.device),
torch.tensor( prompt_tokens,
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), action_code_tokens,
device=action.device, end_tokens,
),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
] ]
) )
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
code_mask[code_start:code_end] = True
# Truncate or pad to max_action_tokens # Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens: if len(tokens) > self.max_action_tokens:
@@ -497,44 +507,49 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"Consider increasing the `max_action_tokens` in your model config if this happens frequently." "Consider increasing the `max_action_tokens` in your model config if this happens frequently."
) )
tokens = tokens[: self.max_action_tokens] tokens = tokens[: self.max_action_tokens]
code_mask = code_mask[: self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device) mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else: else:
pad_len = self.max_action_tokens - len(tokens)
mask = torch.cat( mask = torch.cat(
[ [
torch.ones(len(tokens), dtype=torch.bool, device=action.device), torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros( torch.zeros(pad_len, dtype=torch.bool, device=action.device),
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
),
] ]
) )
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
# Pad tokens with zeros # Pad tokens with zeros
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0) tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
tokens_list.append(tokens) tokens_list.append(tokens)
masks_list.append(mask) masks_list.append(mask)
code_masks_list.append(code_mask)
# Stack into batched tensors # Stack into batched tensors
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens) tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens) masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
# Remove batch dimension if input was single sample # Remove batch dimension if input was single sample
if single_sample: if single_sample:
tokens_batch = tokens_batch.squeeze(0) tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0) masks_batch = masks_batch.squeeze(0)
code_masks_batch = code_masks_batch.squeeze(0)
# Move to the same device as the input # Move to the same device as the input
if device is not None: if device is not None:
tokens_batch = tokens_batch.to(device) tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device) masks_batch = masks_batch.to(device)
code_masks_batch = code_masks_batch.to(device)
return tokens_batch, masks_batch return tokens_batch, masks_batch, code_masks_batch
def action(self, action: torch.Tensor) -> torch.Tensor: def action(self, action: torch.Tensor) -> torch.Tensor:
""" """
This method is not used since we override __call__. This method is not used since we override __call__.
Required by ActionProcessorStep ABC. Required by ActionProcessorStep ABC.
""" """
tokens, _ = self._tokenize_action(action) tokens, _, _ = self._tokenize_action(action)
return tokens return tokens
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
+1
View File
@@ -34,6 +34,7 @@ ACTION = "action"
ACTION_PREFIX = ACTION + "." ACTION_PREFIX = ACTION + "."
ACTION_TOKENS = ACTION + ".tokens" ACTION_TOKENS = ACTION + ".tokens"
ACTION_TOKEN_MASK = ACTION + ".token_mask" ACTION_TOKEN_MASK = ACTION + ".token_mask"
ACTION_CODE_TOKEN_MASK = ACTION + ".code_token_mask"
REWARD = "next.reward" REWARD = "next.reward"
TRUNCATED = "next.truncated" TRUNCATED = "next.truncated"
DONE = "next.done" DONE = "next.done"
@@ -0,0 +1,75 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Regression tests for PI052 FAST action-code supervision."""
import pytest
import torch
from torch.nn import functional as F
pytest.importorskip("transformers")
from lerobot.policies.pi052.modeling_pi052 import _fast_ce # noqa: E402
def test_fast_ce_supervises_only_discrete_action_codes():
"""Wrapper tokens can be wrong without affecting the FAST action-code loss."""
vocab_size = 8
action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]])
action_code_mask = torch.tensor([[False, False, True, True, False, False]])
logits = torch.zeros(1, action_tokens.shape[1], vocab_size)
# Deliberately bad wrapper-token predictions. These should be ignored.
logits[0, 0, 7] = 10.0 # target would be token 2
logits[0, 3, 7] = 10.0 # target would be delimiter token 5
# Correct action-code predictions: hidden t predicts target t + 1.
logits[0, 1, 3] = 10.0
logits[0, 2, 4] = 10.0
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
expected = F.cross_entropy(
torch.stack([logits[0, 1], logits[0, 2]]),
torch.tensor([3, 4]),
reduction="mean",
)
assert torch.allclose(loss, expected)
def test_fast_ce_masks_non_action_samples():
"""Recipe samples with predict_actions=False do not contribute FAST loss."""
vocab_size = 8
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
action_code_mask = torch.tensor(
[[False, False, True, True], [False, False, True, True]]
)
predict_actions = torch.tensor([True, False])
logits = torch.zeros(2, action_tokens.shape[1], vocab_size)
logits[0, 1, 3] = 10.0
logits[0, 2, 4] = 10.0
# Bad predictions in the masked sample should not matter.
logits[1, 1, 7] = 10.0
logits[1, 2, 7] = 10.0
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions)
expected = F.cross_entropy(
torch.stack([logits[0, 1], logits[0, 2]]),
torch.tensor([3, 4]),
reduction="mean",
)
assert torch.allclose(loss, expected)