diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 9069b05ca..c429830b7 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -124,19 +124,19 @@ def _mark_target_span_causal( def _fast_ce( fast_logits: Tensor, action_tokens: Tensor, - action_mask: Tensor, + action_code_mask: Tensor, predict_actions_t: Tensor | None, ) -> Tensor: - """FAST-CE with both token-pad masking and per-sample action gating. + """FAST action-code CE with token-span masking and per-sample action gating. - ``action_mask`` is the FAST tokenizer's padding mask; samples whose - recipe sets ``predict_actions=False`` (e.g. plan_generation, - ask_vqa_*) get *all* their FAST positions masked out via the - per-sample gate. + ``action_code_mask`` is true only on the discrete action-code tokens, + excluding the BOS / "Action: " / delimiter wrapper. Samples whose + recipe sets ``predict_actions=False`` get all code positions masked + out via the per-sample gate. """ shift_logits = fast_logits[:, :-1, :].contiguous() shift_targets = action_tokens[:, 1:].contiguous().long() - shift_valid = action_mask[:, 1:].contiguous().bool() + shift_valid = action_code_mask[:, 1:].contiguous().bool() if predict_actions_t is not None: sample_mask = predict_actions_t[:, None].expand_as(shift_valid) shift_valid = shift_valid & sample_mask @@ -429,13 +429,18 @@ class PI052Policy(PI05Policy): and self.config.fast_action_loss_weight > 0 and (predict_actions_t is None or bool(predict_actions_t.any().item())) ) - action_tokens = action_mask = None + action_tokens = action_mask = action_code_mask = None if run_fast: - from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415 + from lerobot.utils.constants import ( # noqa: PLC0415 + ACTION_CODE_TOKEN_MASK, + ACTION_TOKEN_MASK, + ACTION_TOKENS, + ) action_tokens = batch.get(ACTION_TOKENS) action_mask = batch.get(ACTION_TOKEN_MASK) - if action_tokens is None or action_mask is None: + action_code_mask = batch.get(ACTION_CODE_TOKEN_MASK) + if action_tokens is None or action_mask is None or action_code_mask is None: run_fast = False # ------------------------------------------------------------ @@ -457,6 +462,7 @@ class PI052Policy(PI05Policy): text_labels=text_labels if run_text else None, action_tokens=action_tokens if run_fast else None, action_mask=action_mask if run_fast else None, + action_code_mask=action_code_mask if run_fast else None, predict_actions_t=predict_actions_t, ) loss_dict["flow_loss"] = float(flow_loss.detach().item()) @@ -473,6 +479,7 @@ class PI052Policy(PI05Policy): text_labels=text_labels if run_text else None, action_tokens=action_tokens if run_fast else None, action_mask=action_mask if run_fast else None, + action_code_mask=action_code_mask if run_fast else None, predict_actions_t=predict_actions_t, ) if text_loss is not None: @@ -508,6 +515,7 @@ class PI052Policy(PI05Policy): text_labels: Tensor | None, action_tokens: Tensor | None, action_mask: Tensor | None, + action_code_mask: Tensor | None, predict_actions_t: Tensor | None = None, ) -> tuple[Tensor, Tensor | None, Tensor | None]: """Full fusion: flow + text + FAST in ONE backbone forward. @@ -647,10 +655,10 @@ class PI052Policy(PI05Policy): text_loss = _shifted_ce(text_logits, text_labels) fast_loss: Tensor | None = None - if fast_len > 0 and prefix_out is not None: + if fast_len > 0 and prefix_out is not None and action_code_mask is not None: fast_hidden = prefix_out[:, -fast_len:, :] fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) - fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t) + fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) return flow_loss, text_loss, fast_loss @@ -660,6 +668,7 @@ class PI052Policy(PI05Policy): text_labels: Tensor | None, action_tokens: Tensor | None, action_mask: Tensor | None, + action_code_mask: Tensor | None, predict_actions_t: Tensor | None = None, ) -> tuple[Tensor | None, Tensor | None]: """Single prefix forward → text CE + FAST CE. @@ -745,10 +754,14 @@ class PI052Policy(PI05Policy): text_loss = _shifted_ce(text_logits, text_labels) fast_loss: Tensor | None = None - if action_tokens is not None and action_mask is not None and fast_len > 0: + if ( + action_tokens is not None + and action_code_mask is not None + and fast_len > 0 + ): fast_hidden = vlm_out[:, -fast_len:, :] fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) - fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t) + fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) return text_loss, fast_loss diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index a808e6127..f11f9f87a 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -32,6 +32,7 @@ import torch from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.types import EnvTransition, RobotObservation, TransitionKey from lerobot.utils.constants import ( + ACTION_CODE_TOKEN_MASK, ACTION_TOKEN_MASK, ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, @@ -412,14 +413,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): # During inference, no action is available, skip tokenization return new_transition - # Tokenize and get both tokens and mask - tokens, mask = self._tokenize_action(action) + # Tokenize and get masks for the full formatted sequence and the discrete action codes. + tokens, mask, code_mask = self._tokenize_action(action) # Store mask in complementary data complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) if complementary_data is None: complementary_data = {} complementary_data[ACTION_TOKEN_MASK] = mask + complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask complementary_data[ACTION_TOKENS] = tokens new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data return new_transition @@ -430,7 +432,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): """ return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens - def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Tokenizes the action tensor and creates a mask. @@ -459,6 +461,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): # The fast tokenizer expects action data and returns token IDs tokens_list = [] masks_list = [] + code_masks_list = [] for i in range(batch_size): # Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy) @@ -476,19 +479,26 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): if tokens.dim() > 1: tokens = tokens.flatten() + action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens) bos_id = self._paligemma_tokenizer.bos_token_id - # add bos + prompt_tokens = torch.tensor( + self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), + device=action.device, + ) + end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device) + + code_start = 1 + len(prompt_tokens) + code_end = code_start + len(action_code_tokens) tokens = torch.cat( [ torch.tensor([bos_id], device=action.device), - torch.tensor( - self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), - device=action.device, - ), - self._act_tokens_to_paligemma_tokens(tokens), - torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device), + prompt_tokens, + action_code_tokens, + end_tokens, ] ) + code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device) + code_mask[code_start:code_end] = True # Truncate or pad to max_action_tokens if len(tokens) > self.max_action_tokens: @@ -497,44 +507,49 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): "Consider increasing the `max_action_tokens` in your model config if this happens frequently." ) tokens = tokens[: self.max_action_tokens] + code_mask = code_mask[: self.max_action_tokens] mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device) else: + pad_len = self.max_action_tokens - len(tokens) mask = torch.cat( [ torch.ones(len(tokens), dtype=torch.bool, device=action.device), - torch.zeros( - self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device - ), + torch.zeros(pad_len, dtype=torch.bool, device=action.device), ] ) + code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False) # Pad tokens with zeros - tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0) + tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0) tokens_list.append(tokens) masks_list.append(mask) + code_masks_list.append(code_mask) # Stack into batched tensors tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens) masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens) + code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens) # Remove batch dimension if input was single sample if single_sample: tokens_batch = tokens_batch.squeeze(0) masks_batch = masks_batch.squeeze(0) + code_masks_batch = code_masks_batch.squeeze(0) # Move to the same device as the input if device is not None: tokens_batch = tokens_batch.to(device) masks_batch = masks_batch.to(device) + code_masks_batch = code_masks_batch.to(device) - return tokens_batch, masks_batch + return tokens_batch, masks_batch, code_masks_batch def action(self, action: torch.Tensor) -> torch.Tensor: """ This method is not used since we override __call__. Required by ActionProcessorStep ABC. """ - tokens, _ = self._tokenize_action(action) + tokens, _, _ = self._tokenize_action(action) return tokens def get_config(self) -> dict[str, Any]: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 43869228d..6d1df20af 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -34,6 +34,7 @@ ACTION = "action" ACTION_PREFIX = ACTION + "." ACTION_TOKENS = ACTION + ".tokens" ACTION_TOKEN_MASK = ACTION + ".token_mask" +ACTION_CODE_TOKEN_MASK = ACTION + ".code_token_mask" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" diff --git a/tests/policies/pi052/test_pi052_fast_action_loss.py b/tests/policies/pi052/test_pi052_fast_action_loss.py new file mode 100644 index 000000000..9839db28c --- /dev/null +++ b/tests/policies/pi052/test_pi052_fast_action_loss.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regression tests for PI052 FAST action-code supervision.""" + +import pytest +import torch +from torch.nn import functional as F + +pytest.importorskip("transformers") + +from lerobot.policies.pi052.modeling_pi052 import _fast_ce # noqa: E402 + + +def test_fast_ce_supervises_only_discrete_action_codes(): + """Wrapper tokens can be wrong without affecting the FAST action-code loss.""" + vocab_size = 8 + action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]]) + action_code_mask = torch.tensor([[False, False, True, True, False, False]]) + + logits = torch.zeros(1, action_tokens.shape[1], vocab_size) + # Deliberately bad wrapper-token predictions. These should be ignored. + logits[0, 0, 7] = 10.0 # target would be token 2 + logits[0, 3, 7] = 10.0 # target would be delimiter token 5 + # Correct action-code predictions: hidden t predicts target t + 1. + logits[0, 1, 3] = 10.0 + logits[0, 2, 4] = 10.0 + + loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None) + expected = F.cross_entropy( + torch.stack([logits[0, 1], logits[0, 2]]), + torch.tensor([3, 4]), + reduction="mean", + ) + + assert torch.allclose(loss, expected) + + +def test_fast_ce_masks_non_action_samples(): + """Recipe samples with predict_actions=False do not contribute FAST loss.""" + vocab_size = 8 + action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]]) + action_code_mask = torch.tensor( + [[False, False, True, True], [False, False, True, True]] + ) + predict_actions = torch.tensor([True, False]) + + logits = torch.zeros(2, action_tokens.shape[1], vocab_size) + logits[0, 1, 3] = 10.0 + logits[0, 2, 4] = 10.0 + # Bad predictions in the masked sample should not matter. + logits[1, 1, 7] = 10.0 + logits[1, 2, 7] = 10.0 + + loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions) + expected = F.cross_entropy( + torch.stack([logits[0, 1], logits[0, 2]]), + torch.tensor([3, 4]), + reduction="mean", + ) + + assert torch.allclose(loss, expected)