fix(pi052): handle batched rendered messages

Tokenize batched recipe outputs in PI052 so training batches with nested message lists do not crash before model forward.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 17:41:58 +00:00
parent 0e2dc1b76f
commit 1750a87104
2 changed files with 176 additions and 38 deletions
@@ -21,7 +21,11 @@ PaliGemma's flat prompt has no structured tool calls, so an assistant
marker — otherwise the spoken reply is dropped and never supervised.
"""
from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls
import torch
from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
def _say_call(text):
@@ -58,3 +62,66 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content():
)
assert out["content"] == "plan only"
assert "tool_calls" not in out
class _CharTokenizer:
pad_token_id = 0
def __call__(
self,
text,
max_length,
padding,
truncation,
return_tensors,
return_offsets_mapping,
padding_side,
):
ids = [ord(c) % 251 + 1 for c in text[:max_length]]
offsets = [(i, i + 1) for i in range(len(ids))]
attention = [1] * len(ids)
if padding == "max_length" and len(ids) < max_length:
pad = max_length - len(ids)
ids += [self.pad_token_id] * pad
offsets += [(0, 0)] * pad
attention += [0] * pad
return {
"input_ids": torch.tensor([ids], dtype=torch.long),
"attention_mask": torch.tensor([attention], dtype=torch.long),
"offset_mapping": torch.tensor([offsets], dtype=torch.long),
}
def decode(self, token_ids, skip_special_tokens=False):
return "".join(chr(max(int(i) - 1, 0)) for i in token_ids if int(i) != self.pad_token_id)
def test_pi052_text_tokenizer_handles_batched_rendered_messages():
step = PI052TextTokenizerStep(max_length=64)
step._tokenizer = _CharTokenizer()
transition = {
TransitionKey.OBSERVATION: {},
TransitionKey.COMPLEMENTARY_DATA: {
"messages": [
[
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
],
[{"role": "user", "content": "open drawer"}],
],
"target_message_indices": [[1], []],
"message_streams": [["high_level", "high_level"], ["low_level"]],
"index": torch.tensor([10, 11]),
},
}
out = step(transition)
obs = out[TransitionKey.OBSERVATION]
comp = out[TransitionKey.COMPLEMENTARY_DATA]
assert obs[OBS_LANGUAGE_TOKENS].shape == (2, 64)
assert obs[OBS_LANGUAGE_ATTENTION_MASK].shape == (2, 64)
assert comp["text_labels"].shape == (2, 64)
assert comp["predict_actions"].tolist() == [False, True]
assert (comp["text_labels"][0] != -100).any()
assert not (comp["text_labels"][1] != -100).any()