fix(pi052): handle batched rendered messages

Tokenize batched recipe outputs in PI052 so that training batches containing nested message lists do not crash before the model's forward pass.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 17:41:58 +00:00
parent 0e2dc1b76f
commit 1750a87104
2 changed files with 176 additions and 38 deletions
@@ -42,6 +42,7 @@ from dataclasses import dataclass
from typing import Any from typing import Any
import torch import torch
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
@@ -214,6 +215,25 @@ def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
return new return new
def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
if value is None:
return [None] * batch_size
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return [int(value.item())] * batch_size
values = value.reshape(-1).tolist()
return [int(v) for v in values[:batch_size]]
if isinstance(value, (list, tuple)):
if len(value) == 1:
return _sample_indices(value[0], batch_size)
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
return [int(value)] * batch_size
def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]: def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]:
"""Concatenate messages into the π0.5-style flat prompt. """Concatenate messages into the π0.5-style flat prompt.
@@ -285,8 +305,6 @@ class PI052TextTokenizerStep(ProcessorStep):
transition = transition.copy() transition = transition.copy()
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
messages = complementary.get("messages") or [] messages = complementary.get("messages") or []
target_indices = list(complementary.get("target_message_indices") or [])
message_streams = list(complementary.get("message_streams") or [])
if not messages: if not messages:
# No recipe was rendered — caller will fall back to the # No recipe was rendered — caller will fall back to the
@@ -294,6 +312,90 @@ class PI052TextTokenizerStep(ProcessorStep):
# unmodified. # unmodified.
return transition return transition
tokenizer = self._ensure_tokenizer()
if _is_batched_messages(messages):
indices_iter = _sample_indices(complementary.get("index"), len(messages))
encoded = [
self._encode_messages(
tokenizer,
msg,
list(streams),
list(tgt_indices),
complementary,
sample_idx=int(s_idx) if s_idx is not None else None,
)
for msg, streams, tgt_indices, s_idx in zip(
messages,
complementary.get("message_streams") or [[] for _ in messages],
complementary.get("target_message_indices") or [[] for _ in messages],
indices_iter,
strict=False,
)
]
else:
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
encoded = [
self._encode_messages(
tokenizer,
messages,
list(complementary.get("message_streams") or []),
list(complementary.get("target_message_indices") or []),
complementary,
sample_idx=sample_idx,
)
]
if _DUMP_BUDGET > 0:
if _is_batched_messages(messages):
msgs_iter = messages
streams_iter = complementary.get("message_streams") or [[] for _ in messages]
targets_iter = complementary.get("target_message_indices") or [[] for _ in messages]
else:
msgs_iter = [messages]
streams_iter = [list(complementary.get("message_streams") or [])]
targets_iter = [list(complementary.get("target_message_indices") or [])]
for msg, streams, targets, (ids, attn, labels, predict_action, prompt) in zip(
msgs_iter, streams_iter, targets_iter, encoded, strict=False
):
target_set = {int(i) for i in targets}
annotated_msgs = [
{
**m,
"stream": streams[i] if i < len(streams) else None,
"target": True if i in target_set else None,
}
for i, m in enumerate(msg)
]
_dump_recipe_sample(
messages=annotated_msgs,
prompt_text=prompt,
token_ids=ids.tolist(),
labels=labels.tolist(),
predict_actions=bool(predict_action.item()),
tokenizer=tokenizer,
)
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = {
**complementary,
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
}
return transition
def _encode_messages(
self,
tokenizer: Any,
messages: list[dict[str, Any]],
message_streams: list[str | None],
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
# Optional: drop non-target messages per the dropout config. # Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping # Keeps the supervised-target indices stable by re-mapping
# after removal. # after removal.
@@ -307,6 +409,7 @@ class PI052TextTokenizerStep(ProcessorStep):
messages, messages,
target_indices, target_indices,
complementary, complementary,
sample_idx=sample_idx,
) )
# Flatten ``say`` tool calls into ``<say>...</say>`` text before # Flatten ``say`` tool calls into ``<say>...</say>`` text before
@@ -315,7 +418,6 @@ class PI052TextTokenizerStep(ProcessorStep):
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages] messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
prompt, spans = _format_messages(messages) prompt, spans = _format_messages(messages)
tokenizer = self._ensure_tokenizer()
encoded = tokenizer( encoded = tokenizer(
prompt, prompt,
max_length=self.max_length, max_length=self.max_length,
@@ -354,39 +456,7 @@ class PI052TextTokenizerStep(ProcessorStep):
bool(any(s == "low_level" for s in message_streams)), bool(any(s == "low_level" for s in message_streams)),
dtype=torch.bool, dtype=torch.bool,
) )
return input_ids, attention_mask, labels, predict_actions, prompt
if _DUMP_BUDGET > 0:
# Stream / target metadata live in parallel arrays; zip them
# back into the dicts so the dump shows them per message.
target_set = {int(i) for i in target_indices}
annotated_msgs = [
{
**m,
"stream": message_streams[i] if i < len(message_streams) else None,
"target": True if i in target_set else None,
}
for i, m in enumerate(messages)
]
_dump_recipe_sample(
messages=annotated_msgs,
prompt_text=prompt,
token_ids=input_ids.tolist(),
labels=labels.tolist(),
predict_actions=bool(predict_actions.item()),
tokenizer=tokenizer,
)
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0)
obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0)
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = {
**complementary,
"text_labels": labels.unsqueeze(0),
"predict_actions": predict_actions.unsqueeze(0),
}
return transition
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Per-component prompt dropout (Pi0.7 §V.E) # Per-component prompt dropout (Pi0.7 §V.E)
@@ -397,6 +467,7 @@ class PI052TextTokenizerStep(ProcessorStep):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
target_indices: list[int], target_indices: list[int],
complementary: dict[str, Any], complementary: dict[str, Any],
sample_idx: int | None = None,
) -> tuple[list[dict[str, Any]], list[int]]: ) -> tuple[list[dict[str, Any]], list[int]]:
"""Drop messages classified as plan/memory/subtask context. """Drop messages classified as plan/memory/subtask context.
@@ -411,7 +482,7 @@ class PI052TextTokenizerStep(ProcessorStep):
# ``render_messages_processor``. Falling back to other # ``render_messages_processor``. Falling back to other
# keys silently gave every sample seed=0 → identical # keys silently gave every sample seed=0 → identical
# dropout pattern across the whole epoch. # dropout pattern across the whole epoch.
seed_src = complementary.get("index", 0) seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
try: try:
if hasattr(seed_src, "item"): if hasattr(seed_src, "item"):
seed_src = seed_src.item() seed_src = seed_src.item()
@@ -21,7 +21,11 @@ PaliGemma's flat prompt has no structured tool calls, so an assistant
marker otherwise the spoken reply is dropped and never supervised. marker otherwise the spoken reply is dropped and never supervised.
""" """
from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls import torch
from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
def _say_call(text): def _say_call(text):
@@ -58,3 +62,66 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content():
) )
assert out["content"] == "plan only" assert out["content"] == "plan only"
assert "tool_calls" not in out assert "tool_calls" not in out
class _CharTokenizer:
pad_token_id = 0
def __call__(
self,
text,
max_length,
padding,
truncation,
return_tensors,
return_offsets_mapping,
padding_side,
):
ids = [ord(c) % 251 + 1 for c in text[:max_length]]
offsets = [(i, i + 1) for i in range(len(ids))]
attention = [1] * len(ids)
if padding == "max_length" and len(ids) < max_length:
pad = max_length - len(ids)
ids += [self.pad_token_id] * pad
offsets += [(0, 0)] * pad
attention += [0] * pad
return {
"input_ids": torch.tensor([ids], dtype=torch.long),
"attention_mask": torch.tensor([attention], dtype=torch.long),
"offset_mapping": torch.tensor([offsets], dtype=torch.long),
}
def decode(self, token_ids, skip_special_tokens=False):
return "".join(chr(max(int(i) - 1, 0)) for i in token_ids if int(i) != self.pad_token_id)
def test_pi052_text_tokenizer_handles_batched_rendered_messages():
    """A two-sample batch of rendered messages is tokenized into stacked outputs."""
    # Sample 0: user + assistant turn, with the assistant message (index 1)
    # marked as target. Sample 1: a single user turn with no target.
    first_sample = [
        {"role": "user", "content": "pick cube"},
        {"role": "assistant", "content": "move to cube"},
    ]
    second_sample = [{"role": "user", "content": "open drawer"}]
    complementary_in = {
        "messages": [first_sample, second_sample],
        "target_message_indices": [[1], []],
        "message_streams": [["high_level", "high_level"], ["low_level"]],
        "index": torch.tensor([10, 11]),
    }

    tokenizer_step = PI052TextTokenizerStep(max_length=64)
    # Swap in the character-level stub tokenizer defined in this test module.
    tokenizer_step._tokenizer = _CharTokenizer()

    result = tokenizer_step(
        {
            TransitionKey.OBSERVATION: {},
            TransitionKey.COMPLEMENTARY_DATA: complementary_in,
        }
    )
    observation = result[TransitionKey.OBSERVATION]
    complementary_out = result[TransitionKey.COMPLEMENTARY_DATA]

    # One row per sample, each padded to max_length.
    expected_shape = (2, 64)
    assert observation[OBS_LANGUAGE_TOKENS].shape == expected_shape
    assert observation[OBS_LANGUAGE_ATTENTION_MASK].shape == expected_shape
    assert complementary_out["text_labels"].shape == expected_shape

    # Only the sample carrying a "low_level" stream is flagged for actions.
    assert complementary_out["predict_actions"].tolist() == [False, True]

    # -100 marks unsupervised positions: sample 0 has a target message so some
    # labels are supervised; sample 1 has none.
    supervised = complementary_out["text_labels"] != -100
    assert supervised[0].any()
    assert not supervised[1].any()