From 1750a87104015dc665e28d79eae3f8768f645833 Mon Sep 17 00:00:00 2001 From: pepijn Date: Mon, 18 May 2026 17:41:58 +0000 Subject: [PATCH] fix(pi052): handle batched rendered messages Tokenize batched recipe outputs in PI052 so training batches with nested message lists do not crash before model forward. Co-authored-by: Cursor --- .../policies/pi052/text_processor_pi052.py | 145 +++++++++++++----- .../pi052/test_pi052_text_processor.py | 69 ++++++++- 2 files changed, 176 insertions(+), 38 deletions(-) diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py index 38a4d082e..559e51849 100644 --- a/src/lerobot/policies/pi052/text_processor_pi052.py +++ b/src/lerobot/policies/pi052/text_processor_pi052.py @@ -42,6 +42,7 @@ from dataclasses import dataclass from typing import Any import torch +from torch import Tensor from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry @@ -214,6 +215,25 @@ def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]: return new +def _is_batched_messages(messages: Any) -> bool: + return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) + + +def _sample_indices(value: Any, batch_size: int) -> list[int | None]: + if value is None: + return [None] * batch_size + if isinstance(value, torch.Tensor): + if value.numel() == 1: + return [int(value.item())] * batch_size + values = value.reshape(-1).tolist() + return [int(v) for v in values[:batch_size]] + if isinstance(value, (list, tuple)): + if len(value) == 1: + return _sample_indices(value[0], batch_size) + return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]] + return [int(value)] * batch_size + + def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]: """Concatenate messages into the π0.5-style flat prompt. @@ -285,8 +305,6 @@ class PI052TextTokenizerStep(ProcessorStep): transition = transition.copy() complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} messages = complementary.get("messages") or [] - target_indices = list(complementary.get("target_message_indices") or []) - message_streams = list(complementary.get("message_streams") or []) if not messages: # No recipe was rendered — caller will fall back to the @@ -294,6 +312,90 @@ class PI052TextTokenizerStep(ProcessorStep): # unmodified. return transition + tokenizer = self._ensure_tokenizer() + if _is_batched_messages(messages): + indices_iter = _sample_indices(complementary.get("index"), len(messages)) + encoded = [ + self._encode_messages( + tokenizer, + msg, + list(streams), + list(tgt_indices), + complementary, + sample_idx=int(s_idx) if s_idx is not None else None, + ) + for msg, streams, tgt_indices, s_idx in zip( + messages, + complementary.get("message_streams") or [[] for _ in messages], + complementary.get("target_message_indices") or [[] for _ in messages], + indices_iter, + strict=False, + ) + ] + else: + sample_idx = _sample_indices(complementary.get("index"), 1)[0] + encoded = [ + self._encode_messages( + tokenizer, + messages, + list(complementary.get("message_streams") or []), + list(complementary.get("target_message_indices") or []), + complementary, + sample_idx=sample_idx, + ) + ] + + if _DUMP_BUDGET > 0: + if _is_batched_messages(messages): + msgs_iter = messages + streams_iter = complementary.get("message_streams") or [[] for _ in messages] + targets_iter = complementary.get("target_message_indices") or [[] for _ in messages] + else: + msgs_iter = [messages] + streams_iter = [list(complementary.get("message_streams") or [])] + targets_iter = [list(complementary.get("target_message_indices") or [])] + for msg, streams, targets, (ids, attn, labels, predict_action, prompt) in zip( + msgs_iter, streams_iter, targets_iter, encoded, strict=False + ): + target_set = {int(i) for i in targets} + annotated_msgs = [ + { + **m, + "stream": streams[i] if i < len(streams) else None, + "target": True if i in target_set else None, + } + for i, m in enumerate(msg) + ] + _dump_recipe_sample( + messages=annotated_msgs, + prompt_text=prompt, + token_ids=ids.tolist(), + labels=labels.tolist(), + predict_actions=bool(predict_action.item()), + tokenizer=tokenizer, + ) + + obs = dict(transition.get(TransitionKey.OBSERVATION) or {}) + obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded]) + obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded]) + transition[TransitionKey.OBSERVATION] = obs + + transition[TransitionKey.COMPLEMENTARY_DATA] = { + **complementary, + "text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]), + "predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]), + } + return transition + + def _encode_messages( + self, + tokenizer: Any, + messages: list[dict[str, Any]], + message_streams: list[str | None], + target_indices: list[int], + complementary: dict[str, Any], + sample_idx: int | None = None, + ) -> tuple[Tensor, Tensor, Tensor, Tensor, str]: # Optional: drop non-target messages per the dropout config. # Keeps the supervised-target indices stable by re-mapping # after removal. @@ -307,6 +409,7 @@ class PI052TextTokenizerStep(ProcessorStep): messages, target_indices, complementary, + sample_idx=sample_idx, ) # Flatten ``say`` tool calls into ``...`` text before @@ -315,7 +418,6 @@ class PI052TextTokenizerStep(ProcessorStep): messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages] prompt, spans = _format_messages(messages) - tokenizer = self._ensure_tokenizer() encoded = tokenizer( prompt, max_length=self.max_length, @@ -354,39 +456,7 @@ class PI052TextTokenizerStep(ProcessorStep): bool(any(s == "low_level" for s in message_streams)), dtype=torch.bool, ) - - if _DUMP_BUDGET > 0: - # Stream / target metadata live in parallel arrays; zip them - # back into the dicts so the dump shows them per message. - target_set = {int(i) for i in target_indices} - annotated_msgs = [ - { - **m, - "stream": message_streams[i] if i < len(message_streams) else None, - "target": True if i in target_set else None, - } - for i, m in enumerate(messages) - ] - _dump_recipe_sample( - messages=annotated_msgs, - prompt_text=prompt, - token_ids=input_ids.tolist(), - labels=labels.tolist(), - predict_actions=bool(predict_actions.item()), - tokenizer=tokenizer, - ) - - obs = dict(transition.get(TransitionKey.OBSERVATION) or {}) - obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0) - obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0) - transition[TransitionKey.OBSERVATION] = obs - - transition[TransitionKey.COMPLEMENTARY_DATA] = { - **complementary, - "text_labels": labels.unsqueeze(0), - "predict_actions": predict_actions.unsqueeze(0), - } - return transition + return input_ids, attention_mask, labels, predict_actions, prompt # ------------------------------------------------------------------ # Per-component prompt dropout (Pi0.7 §V.E) @@ -397,6 +467,7 @@ class PI052TextTokenizerStep(ProcessorStep): messages: list[dict[str, Any]], target_indices: list[int], complementary: dict[str, Any], + sample_idx: int | None = None, ) -> tuple[list[dict[str, Any]], list[int]]: """Drop messages classified as plan/memory/subtask context. @@ -411,7 +482,7 @@ class PI052TextTokenizerStep(ProcessorStep): # ``render_messages_processor``. Falling back to other # keys silently gave every sample seed=0 → identical # dropout pattern across the whole epoch. - seed_src = complementary.get("index", 0) + seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0) try: if hasattr(seed_src, "item"): seed_src = seed_src.item() diff --git a/tests/policies/pi052/test_pi052_text_processor.py b/tests/policies/pi052/test_pi052_text_processor.py index 918582845..9547c2a20 100644 --- a/tests/policies/pi052/test_pi052_text_processor.py +++ b/tests/policies/pi052/test_pi052_text_processor.py @@ -21,7 +21,11 @@ PaliGemma's flat prompt has no structured tool calls, so an assistant marker — otherwise the spoken reply is dropped and never supervised. """ -from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls +import torch + +from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls +from lerobot.types import TransitionKey +from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS def _say_call(text): @@ -58,3 +62,66 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content(): ) assert out["content"] == "plan only" assert "tool_calls" not in out + + +class _CharTokenizer: + pad_token_id = 0 + + def __call__( + self, + text, + max_length, + padding, + truncation, + return_tensors, + return_offsets_mapping, + padding_side, + ): + ids = [ord(c) % 251 + 1 for c in text[:max_length]] + offsets = [(i, i + 1) for i in range(len(ids))] + attention = [1] * len(ids) + if padding == "max_length" and len(ids) < max_length: + pad = max_length - len(ids) + ids += [self.pad_token_id] * pad + offsets += [(0, 0)] * pad + attention += [0] * pad + return { + "input_ids": torch.tensor([ids], dtype=torch.long), + "attention_mask": torch.tensor([attention], dtype=torch.long), + "offset_mapping": torch.tensor([offsets], dtype=torch.long), + } + + def decode(self, token_ids, skip_special_tokens=False): + return "".join(chr(max(int(i) - 1, 0)) for i in token_ids if int(i) != self.pad_token_id) + + +def test_pi052_text_tokenizer_handles_batched_rendered_messages(): + step = PI052TextTokenizerStep(max_length=64) + step._tokenizer = _CharTokenizer() + + transition = { + TransitionKey.OBSERVATION: {}, + TransitionKey.COMPLEMENTARY_DATA: { + "messages": [ + [ + {"role": "user", "content": "pick cube"}, + {"role": "assistant", "content": "move to cube"}, + ], + [{"role": "user", "content": "open drawer"}], + ], + "target_message_indices": [[1], []], + "message_streams": [["high_level", "high_level"], ["low_level"]], + "index": torch.tensor([10, 11]), + }, + } + + out = step(transition) + obs = out[TransitionKey.OBSERVATION] + comp = out[TransitionKey.COMPLEMENTARY_DATA] + + assert obs[OBS_LANGUAGE_TOKENS].shape == (2, 64) + assert obs[OBS_LANGUAGE_ATTENTION_MASK].shape == (2, 64) + assert comp["text_labels"].shape == (2, 64) + assert comp["predict_actions"].tolist() == [False, True] + assert (comp["text_labels"][0] != -100).any() + assert not (comp["text_labels"][1] != -100).any()