mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix(pi052): supervise only FAST action-code tokens
Mask the FAST auxiliary loss to discrete action-code tokens so wrapper formatting tokens do not affect action co-training. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -124,19 +124,19 @@ def _mark_target_span_causal(
|
||||
def _fast_ce(
|
||||
fast_logits: Tensor,
|
||||
action_tokens: Tensor,
|
||||
action_mask: Tensor,
|
||||
action_code_mask: Tensor,
|
||||
predict_actions_t: Tensor | None,
|
||||
) -> Tensor:
|
||||
"""FAST-CE with both token-pad masking and per-sample action gating.
|
||||
"""FAST action-code CE with token-span masking and per-sample action gating.
|
||||
|
||||
``action_mask`` is the FAST tokenizer's padding mask; samples whose
|
||||
recipe sets ``predict_actions=False`` (e.g. plan_generation,
|
||||
ask_vqa_*) get *all* their FAST positions masked out via the
|
||||
per-sample gate.
|
||||
``action_code_mask`` is true only on the discrete action-code tokens,
|
||||
excluding the BOS / "Action: " / delimiter wrapper. Samples whose
|
||||
recipe sets ``predict_actions=False`` get all code positions masked
|
||||
out via the per-sample gate.
|
||||
"""
|
||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
||||
shift_valid = action_code_mask[:, 1:].contiguous().bool()
|
||||
if predict_actions_t is not None:
|
||||
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
||||
shift_valid = shift_valid & sample_mask
|
||||
@@ -429,13 +429,18 @@ class PI052Policy(PI05Policy):
|
||||
and self.config.fast_action_loss_weight > 0
|
||||
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
|
||||
)
|
||||
action_tokens = action_mask = None
|
||||
action_tokens = action_mask = action_code_mask = None
|
||||
if run_fast:
|
||||
from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
ACTION_CODE_TOKEN_MASK,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
)
|
||||
|
||||
action_tokens = batch.get(ACTION_TOKENS)
|
||||
action_mask = batch.get(ACTION_TOKEN_MASK)
|
||||
if action_tokens is None or action_mask is None:
|
||||
action_code_mask = batch.get(ACTION_CODE_TOKEN_MASK)
|
||||
if action_tokens is None or action_mask is None or action_code_mask is None:
|
||||
run_fast = False
|
||||
|
||||
# ------------------------------------------------------------
|
||||
@@ -457,6 +462,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels=text_labels if run_text else None,
|
||||
action_tokens=action_tokens if run_fast else None,
|
||||
action_mask=action_mask if run_fast else None,
|
||||
action_code_mask=action_code_mask if run_fast else None,
|
||||
predict_actions_t=predict_actions_t,
|
||||
)
|
||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||
@@ -473,6 +479,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels=text_labels if run_text else None,
|
||||
action_tokens=action_tokens if run_fast else None,
|
||||
action_mask=action_mask if run_fast else None,
|
||||
action_code_mask=action_code_mask if run_fast else None,
|
||||
predict_actions_t=predict_actions_t,
|
||||
)
|
||||
if text_loss is not None:
|
||||
@@ -508,6 +515,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels: Tensor | None,
|
||||
action_tokens: Tensor | None,
|
||||
action_mask: Tensor | None,
|
||||
action_code_mask: Tensor | None,
|
||||
predict_actions_t: Tensor | None = None,
|
||||
) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
||||
"""Full fusion: flow + text + FAST in ONE backbone forward.
|
||||
@@ -647,10 +655,10 @@ class PI052Policy(PI05Policy):
|
||||
text_loss = _shifted_ce(text_logits, text_labels)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if fast_len > 0 and prefix_out is not None:
|
||||
if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
|
||||
fast_hidden = prefix_out[:, -fast_len:, :]
|
||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
|
||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
|
||||
|
||||
return flow_loss, text_loss, fast_loss
|
||||
|
||||
@@ -660,6 +668,7 @@ class PI052Policy(PI05Policy):
|
||||
text_labels: Tensor | None,
|
||||
action_tokens: Tensor | None,
|
||||
action_mask: Tensor | None,
|
||||
action_code_mask: Tensor | None,
|
||||
predict_actions_t: Tensor | None = None,
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Single prefix forward → text CE + FAST CE.
|
||||
@@ -745,10 +754,14 @@ class PI052Policy(PI05Policy):
|
||||
text_loss = _shifted_ce(text_logits, text_labels)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if action_tokens is not None and action_mask is not None and fast_len > 0:
|
||||
if (
|
||||
action_tokens is not None
|
||||
and action_code_mask is not None
|
||||
and fast_len > 0
|
||||
):
|
||||
fast_hidden = vlm_out[:, -fast_len:, :]
|
||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_mask, predict_actions_t)
|
||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
|
||||
|
||||
return text_loss, fast_loss
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ import torch
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_CODE_TOKEN_MASK,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -412,14 +413,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# During inference, no action is available, skip tokenization
|
||||
return new_transition
|
||||
|
||||
# Tokenize and get both tokens and mask
|
||||
tokens, mask = self._tokenize_action(action)
|
||||
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
|
||||
tokens, mask, code_mask = self._tokenize_action(action)
|
||||
|
||||
# Store mask in complementary data
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if complementary_data is None:
|
||||
complementary_data = {}
|
||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
|
||||
complementary_data[ACTION_TOKENS] = tokens
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return new_transition
|
||||
@@ -430,7 +432,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
@@ -459,6 +461,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# The fast tokenizer expects action data and returns token IDs
|
||||
tokens_list = []
|
||||
masks_list = []
|
||||
code_masks_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||
@@ -476,19 +479,26 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
|
||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
# add bos
|
||||
prompt_tokens = torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
)
|
||||
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
|
||||
|
||||
code_start = 1 + len(prompt_tokens)
|
||||
code_end = code_start + len(action_code_tokens)
|
||||
tokens = torch.cat(
|
||||
[
|
||||
torch.tensor([bos_id], device=action.device),
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
),
|
||||
self._act_tokens_to_paligemma_tokens(tokens),
|
||||
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
|
||||
prompt_tokens,
|
||||
action_code_tokens,
|
||||
end_tokens,
|
||||
]
|
||||
)
|
||||
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
|
||||
code_mask[code_start:code_end] = True
|
||||
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
@@ -497,44 +507,49 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
|
||||
)
|
||||
tokens = tokens[: self.max_action_tokens]
|
||||
code_mask = code_mask[: self.max_action_tokens]
|
||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||
else:
|
||||
pad_len = self.max_action_tokens - len(tokens)
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||
torch.zeros(
|
||||
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
|
||||
),
|
||||
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
|
||||
]
|
||||
)
|
||||
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
|
||||
# Pad tokens with zeros
|
||||
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
|
||||
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
|
||||
|
||||
tokens_list.append(tokens)
|
||||
masks_list.append(mask)
|
||||
code_masks_list.append(code_mask)
|
||||
|
||||
# Stack into batched tensors
|
||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
tokens_batch = tokens_batch.squeeze(0)
|
||||
masks_batch = masks_batch.squeeze(0)
|
||||
code_masks_batch = code_masks_batch.squeeze(0)
|
||||
|
||||
# Move to the same device as the input
|
||||
if device is not None:
|
||||
tokens_batch = tokens_batch.to(device)
|
||||
masks_batch = masks_batch.to(device)
|
||||
code_masks_batch = code_masks_batch.to(device)
|
||||
|
||||
return tokens_batch, masks_batch
|
||||
return tokens_batch, masks_batch, code_masks_batch
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is not used since we override __call__.
|
||||
Required by ActionProcessorStep ABC.
|
||||
"""
|
||||
tokens, _ = self._tokenize_action(action)
|
||||
tokens, _, _ = self._tokenize_action(action)
|
||||
return tokens
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
@@ -34,6 +34,7 @@ ACTION = "action"
|
||||
ACTION_PREFIX = ACTION + "."
|
||||
ACTION_TOKENS = ACTION + ".tokens"
|
||||
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||
ACTION_CODE_TOKEN_MASK = ACTION + ".code_token_mask"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Regression tests for PI052 FAST action-code supervision."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi052.modeling_pi052 import _fast_ce # noqa: E402
|
||||
|
||||
|
||||
def test_fast_ce_supervises_only_discrete_action_codes():
|
||||
"""Wrapper tokens can be wrong without affecting the FAST action-code loss."""
|
||||
vocab_size = 8
|
||||
action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]])
|
||||
action_code_mask = torch.tensor([[False, False, True, True, False, False]])
|
||||
|
||||
logits = torch.zeros(1, action_tokens.shape[1], vocab_size)
|
||||
# Deliberately bad wrapper-token predictions. These should be ignored.
|
||||
logits[0, 0, 7] = 10.0 # target would be token 2
|
||||
logits[0, 3, 7] = 10.0 # target would be delimiter token 5
|
||||
# Correct action-code predictions: hidden t predicts target t + 1.
|
||||
logits[0, 1, 3] = 10.0
|
||||
logits[0, 2, 4] = 10.0
|
||||
|
||||
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
|
||||
expected = F.cross_entropy(
|
||||
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||
torch.tensor([3, 4]),
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
assert torch.allclose(loss, expected)
|
||||
|
||||
|
||||
def test_fast_ce_masks_non_action_samples():
|
||||
"""Recipe samples with predict_actions=False do not contribute FAST loss."""
|
||||
vocab_size = 8
|
||||
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
|
||||
action_code_mask = torch.tensor(
|
||||
[[False, False, True, True], [False, False, True, True]]
|
||||
)
|
||||
predict_actions = torch.tensor([True, False])
|
||||
|
||||
logits = torch.zeros(2, action_tokens.shape[1], vocab_size)
|
||||
logits[0, 1, 3] = 10.0
|
||||
logits[0, 2, 4] = 10.0
|
||||
# Bad predictions in the masked sample should not matter.
|
||||
logits[1, 1, 7] = 10.0
|
||||
logits[1, 2, 7] = 10.0
|
||||
|
||||
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions)
|
||||
expected = F.cross_entropy(
|
||||
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||
torch.tensor([3, 4]),
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
assert torch.allclose(loss, expected)
|
||||
Reference in New Issue
Block a user