mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix(pi052): handle batched rendered messages
Tokenize batched recipe outputs in PI052 so training batches with nested message lists do not crash before model forward. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -42,6 +42,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||||
@@ -214,6 +215,25 @@ def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
def _is_batched_messages(messages: Any) -> bool:
|
||||||
|
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||||
|
if value is None:
|
||||||
|
return [None] * batch_size
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
if value.numel() == 1:
|
||||||
|
return [int(value.item())] * batch_size
|
||||||
|
values = value.reshape(-1).tolist()
|
||||||
|
return [int(v) for v in values[:batch_size]]
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
if len(value) == 1:
|
||||||
|
return _sample_indices(value[0], batch_size)
|
||||||
|
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
|
||||||
|
return [int(value)] * batch_size
|
||||||
|
|
||||||
|
|
||||||
def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]:
|
def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]:
|
||||||
"""Concatenate messages into the π0.5-style flat prompt.
|
"""Concatenate messages into the π0.5-style flat prompt.
|
||||||
|
|
||||||
@@ -285,8 +305,6 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
transition = transition.copy()
|
transition = transition.copy()
|
||||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||||
messages = complementary.get("messages") or []
|
messages = complementary.get("messages") or []
|
||||||
target_indices = list(complementary.get("target_message_indices") or [])
|
|
||||||
message_streams = list(complementary.get("message_streams") or [])
|
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
# No recipe was rendered — caller will fall back to the
|
# No recipe was rendered — caller will fall back to the
|
||||||
@@ -294,6 +312,90 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
# unmodified.
|
# unmodified.
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
tokenizer = self._ensure_tokenizer()
|
||||||
|
if _is_batched_messages(messages):
|
||||||
|
indices_iter = _sample_indices(complementary.get("index"), len(messages))
|
||||||
|
encoded = [
|
||||||
|
self._encode_messages(
|
||||||
|
tokenizer,
|
||||||
|
msg,
|
||||||
|
list(streams),
|
||||||
|
list(tgt_indices),
|
||||||
|
complementary,
|
||||||
|
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||||
|
)
|
||||||
|
for msg, streams, tgt_indices, s_idx in zip(
|
||||||
|
messages,
|
||||||
|
complementary.get("message_streams") or [[] for _ in messages],
|
||||||
|
complementary.get("target_message_indices") or [[] for _ in messages],
|
||||||
|
indices_iter,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
|
||||||
|
encoded = [
|
||||||
|
self._encode_messages(
|
||||||
|
tokenizer,
|
||||||
|
messages,
|
||||||
|
list(complementary.get("message_streams") or []),
|
||||||
|
list(complementary.get("target_message_indices") or []),
|
||||||
|
complementary,
|
||||||
|
sample_idx=sample_idx,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if _DUMP_BUDGET > 0:
|
||||||
|
if _is_batched_messages(messages):
|
||||||
|
msgs_iter = messages
|
||||||
|
streams_iter = complementary.get("message_streams") or [[] for _ in messages]
|
||||||
|
targets_iter = complementary.get("target_message_indices") or [[] for _ in messages]
|
||||||
|
else:
|
||||||
|
msgs_iter = [messages]
|
||||||
|
streams_iter = [list(complementary.get("message_streams") or [])]
|
||||||
|
targets_iter = [list(complementary.get("target_message_indices") or [])]
|
||||||
|
for msg, streams, targets, (ids, attn, labels, predict_action, prompt) in zip(
|
||||||
|
msgs_iter, streams_iter, targets_iter, encoded, strict=False
|
||||||
|
):
|
||||||
|
target_set = {int(i) for i in targets}
|
||||||
|
annotated_msgs = [
|
||||||
|
{
|
||||||
|
**m,
|
||||||
|
"stream": streams[i] if i < len(streams) else None,
|
||||||
|
"target": True if i in target_set else None,
|
||||||
|
}
|
||||||
|
for i, m in enumerate(msg)
|
||||||
|
]
|
||||||
|
_dump_recipe_sample(
|
||||||
|
messages=annotated_msgs,
|
||||||
|
prompt_text=prompt,
|
||||||
|
token_ids=ids.tolist(),
|
||||||
|
labels=labels.tolist(),
|
||||||
|
predict_actions=bool(predict_action.item()),
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||||
|
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
|
||||||
|
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
|
||||||
|
transition[TransitionKey.OBSERVATION] = obs
|
||||||
|
|
||||||
|
transition[TransitionKey.COMPLEMENTARY_DATA] = {
|
||||||
|
**complementary,
|
||||||
|
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
|
||||||
|
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
|
||||||
|
}
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def _encode_messages(
|
||||||
|
self,
|
||||||
|
tokenizer: Any,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
message_streams: list[str | None],
|
||||||
|
target_indices: list[int],
|
||||||
|
complementary: dict[str, Any],
|
||||||
|
sample_idx: int | None = None,
|
||||||
|
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
|
||||||
# Optional: drop non-target messages per the dropout config.
|
# Optional: drop non-target messages per the dropout config.
|
||||||
# Keeps the supervised-target indices stable by re-mapping
|
# Keeps the supervised-target indices stable by re-mapping
|
||||||
# after removal.
|
# after removal.
|
||||||
@@ -307,6 +409,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
messages,
|
messages,
|
||||||
target_indices,
|
target_indices,
|
||||||
complementary,
|
complementary,
|
||||||
|
sample_idx=sample_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
|
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
|
||||||
@@ -315,7 +418,6 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
||||||
prompt, spans = _format_messages(messages)
|
prompt, spans = _format_messages(messages)
|
||||||
|
|
||||||
tokenizer = self._ensure_tokenizer()
|
|
||||||
encoded = tokenizer(
|
encoded = tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
@@ -354,39 +456,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
bool(any(s == "low_level" for s in message_streams)),
|
bool(any(s == "low_level" for s in message_streams)),
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
)
|
)
|
||||||
|
return input_ids, attention_mask, labels, predict_actions, prompt
|
||||||
if _DUMP_BUDGET > 0:
|
|
||||||
# Stream / target metadata live in parallel arrays; zip them
|
|
||||||
# back into the dicts so the dump shows them per message.
|
|
||||||
target_set = {int(i) for i in target_indices}
|
|
||||||
annotated_msgs = [
|
|
||||||
{
|
|
||||||
**m,
|
|
||||||
"stream": message_streams[i] if i < len(message_streams) else None,
|
|
||||||
"target": True if i in target_set else None,
|
|
||||||
}
|
|
||||||
for i, m in enumerate(messages)
|
|
||||||
]
|
|
||||||
_dump_recipe_sample(
|
|
||||||
messages=annotated_msgs,
|
|
||||||
prompt_text=prompt,
|
|
||||||
token_ids=input_ids.tolist(),
|
|
||||||
labels=labels.tolist(),
|
|
||||||
predict_actions=bool(predict_actions.item()),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
|
||||||
obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0)
|
|
||||||
obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0)
|
|
||||||
transition[TransitionKey.OBSERVATION] = obs
|
|
||||||
|
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA] = {
|
|
||||||
**complementary,
|
|
||||||
"text_labels": labels.unsqueeze(0),
|
|
||||||
"predict_actions": predict_actions.unsqueeze(0),
|
|
||||||
}
|
|
||||||
return transition
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Per-component prompt dropout (Pi0.7 §V.E)
|
# Per-component prompt dropout (Pi0.7 §V.E)
|
||||||
@@ -397,6 +467,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
target_indices: list[int],
|
target_indices: list[int],
|
||||||
complementary: dict[str, Any],
|
complementary: dict[str, Any],
|
||||||
|
sample_idx: int | None = None,
|
||||||
) -> tuple[list[dict[str, Any]], list[int]]:
|
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||||
"""Drop messages classified as plan/memory/subtask context.
|
"""Drop messages classified as plan/memory/subtask context.
|
||||||
|
|
||||||
@@ -411,7 +482,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
# ``render_messages_processor``. Falling back to other
|
# ``render_messages_processor``. Falling back to other
|
||||||
# keys silently gave every sample seed=0 → identical
|
# keys silently gave every sample seed=0 → identical
|
||||||
# dropout pattern across the whole epoch.
|
# dropout pattern across the whole epoch.
|
||||||
seed_src = complementary.get("index", 0)
|
seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
|
||||||
try:
|
try:
|
||||||
if hasattr(seed_src, "item"):
|
if hasattr(seed_src, "item"):
|
||||||
seed_src = seed_src.item()
|
seed_src = seed_src.item()
|
||||||
|
|||||||
@@ -21,7 +21,11 @@ PaliGemma's flat prompt has no structured tool calls, so an assistant
|
|||||||
marker — otherwise the spoken reply is dropped and never supervised.
|
marker — otherwise the spoken reply is dropped and never supervised.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls
|
import torch
|
||||||
|
|
||||||
|
from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls
|
||||||
|
from lerobot.types import TransitionKey
|
||||||
|
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||||
|
|
||||||
|
|
||||||
def _say_call(text):
|
def _say_call(text):
|
||||||
@@ -58,3 +62,66 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content():
|
|||||||
)
|
)
|
||||||
assert out["content"] == "plan only"
|
assert out["content"] == "plan only"
|
||||||
assert "tool_calls" not in out
|
assert "tool_calls" not in out
|
||||||
|
|
||||||
|
|
||||||
|
class _CharTokenizer:
|
||||||
|
pad_token_id = 0
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
text,
|
||||||
|
max_length,
|
||||||
|
padding,
|
||||||
|
truncation,
|
||||||
|
return_tensors,
|
||||||
|
return_offsets_mapping,
|
||||||
|
padding_side,
|
||||||
|
):
|
||||||
|
ids = [ord(c) % 251 + 1 for c in text[:max_length]]
|
||||||
|
offsets = [(i, i + 1) for i in range(len(ids))]
|
||||||
|
attention = [1] * len(ids)
|
||||||
|
if padding == "max_length" and len(ids) < max_length:
|
||||||
|
pad = max_length - len(ids)
|
||||||
|
ids += [self.pad_token_id] * pad
|
||||||
|
offsets += [(0, 0)] * pad
|
||||||
|
attention += [0] * pad
|
||||||
|
return {
|
||||||
|
"input_ids": torch.tensor([ids], dtype=torch.long),
|
||||||
|
"attention_mask": torch.tensor([attention], dtype=torch.long),
|
||||||
|
"offset_mapping": torch.tensor([offsets], dtype=torch.long),
|
||||||
|
}
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=False):
|
||||||
|
return "".join(chr(max(int(i) - 1, 0)) for i in token_ids if int(i) != self.pad_token_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pi052_text_tokenizer_handles_batched_rendered_messages():
|
||||||
|
step = PI052TextTokenizerStep(max_length=64)
|
||||||
|
step._tokenizer = _CharTokenizer()
|
||||||
|
|
||||||
|
transition = {
|
||||||
|
TransitionKey.OBSERVATION: {},
|
||||||
|
TransitionKey.COMPLEMENTARY_DATA: {
|
||||||
|
"messages": [
|
||||||
|
[
|
||||||
|
{"role": "user", "content": "pick cube"},
|
||||||
|
{"role": "assistant", "content": "move to cube"},
|
||||||
|
],
|
||||||
|
[{"role": "user", "content": "open drawer"}],
|
||||||
|
],
|
||||||
|
"target_message_indices": [[1], []],
|
||||||
|
"message_streams": [["high_level", "high_level"], ["low_level"]],
|
||||||
|
"index": torch.tensor([10, 11]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
out = step(transition)
|
||||||
|
obs = out[TransitionKey.OBSERVATION]
|
||||||
|
comp = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||||
|
|
||||||
|
assert obs[OBS_LANGUAGE_TOKENS].shape == (2, 64)
|
||||||
|
assert obs[OBS_LANGUAGE_ATTENTION_MASK].shape == (2, 64)
|
||||||
|
assert comp["text_labels"].shape == (2, 64)
|
||||||
|
assert comp["predict_actions"].tolist() == [False, True]
|
||||||
|
assert (comp["text_labels"][0] != -100).any()
|
||||||
|
assert not (comp["text_labels"][1] != -100).any()
|
||||||
|
|||||||
Reference in New Issue
Block a user