diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py index 559e51849..6ebd54168 100644 --- a/src/lerobot/policies/pi052/text_processor_pi052.py +++ b/src/lerobot/policies/pi052/text_processor_pi052.py @@ -234,21 +234,33 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]: return [int(value)] * batch_size -def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]: +def _format_messages( + messages: list[dict[str, Any]], + target_indices: list[int] | None = None, + eos_token: str | None = None, +) -> tuple[str, list[tuple[int, int]]]: """Concatenate messages into the π0.5-style flat prompt. + When both ``target_indices`` and ``eos_token`` are given, the EOS + string is appended to each supervised target turn's content and the + returned span covers it — so the label builder marks the EOS token + as a supervised label. That teaches the LM head where the answer + *ends*: without an EOS in the target span the model is never given a + stop signal and rambles to ``max_length`` at inference. Inference + callers omit both args (no EOS baked into the prompt — the model + generates it and ``select_message`` stops on it). + Returns: prompt: the full text the tokenizer will consume. msg_spans: list of ``(char_start, char_end)`` covering each - message's content within ``prompt``. The - target-mask builder uses this to find the - character ranges belonging to the supervised - messages. + message's supervised payload (content, plus the + appended EOS for target turns) within ``prompt``. """ + targets = set(target_indices or []) parts: list[str] = [] spans: list[tuple[int, int]] = [] cursor = 0 - for m in messages: + for i, m in enumerate(messages): role = m.get("role", "user") content = m.get("content", "") or "" # Role tag + newline. 
The model has to learn to emit the same @@ -256,11 +268,15 @@ def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[in # decoding because the chat template is implicit in the # supervised target span. header = f"{role.capitalize()}: " - # span covers ONLY the content portion (so labels are computed - # over the supervised payload, not the role tag). - full = header + content + "\n" + # A supervised target turn ends with EOS so the model learns to + # terminate; the span below covers content + EOS. Non-target + # turns (and inference) carry no EOS. + body = content + eos_token if (eos_token and i in targets) else content + # span covers the content (+ EOS) portion only — never the role + # tag — so labels are computed over the supervised payload. + full = header + body + "\n" start = cursor + len(header) - end = start + len(content) + end = start + len(body) parts.append(full) spans.append((start, end)) cursor += len(full) @@ -416,7 +432,11 @@ class PI052TextTokenizerStep(ProcessorStep): # stripping, so the spoken reply is actually tokenized and # supervised (PaliGemma's flat prompt has no structured calls). messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages] - prompt, spans = _format_messages(messages) + # Append EOS to supervised target turns so the LM head learns to + # stop (the span covers it → it becomes a supervised label). + prompt, spans = _format_messages( + messages, target_indices, getattr(tokenizer, "eos_token", None) + ) encoded = tokenizer( prompt, diff --git a/tests/policies/pi052/test_pi052_text_processor.py b/tests/policies/pi052/test_pi052_text_processor.py index 9547c2a20..77695e12e 100644 --- a/tests/policies/pi052/test_pi052_text_processor.py +++ b/tests/policies/pi052/test_pi052_text_processor.py @@ -14,16 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for PI052's text-tokenizer ``say`` tool-call flattening. 
+"""Tests for PI052's text tokenizer. -PaliGemma's flat prompt has no structured tool calls, so an assistant -``say`` tool call must be serialized into a ``<say>...</say>`` text -marker — otherwise the spoken reply is dropped and never supervised. +Covers ``say`` tool-call flattening (PaliGemma's flat prompt has no +structured tool calls, so a ``say`` call must be serialized into a +``<say>...</say>`` text marker) and EOS-termination supervision (the +supervised target span must end with an EOS token so the LM head learns +to stop instead of rambling to ``max_length`` at inference). """ import torch -from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls +from lerobot.policies.pi052.text_processor_pi052 import ( + PI052TextTokenizerStep, + _flatten_say_tool_calls, + _format_messages, +) from lerobot.types import TransitionKey from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS @@ -64,8 +70,71 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content(): assert "tool_calls" not in out +# --------------------------------------------------------------------------- +# EOS-termination supervision +# --------------------------------------------------------------------------- + + +def test_format_messages_appends_eos_to_target_turns_only(): + msgs = [ + {"role": "user", "content": "pick cube"}, + {"role": "assistant", "content": "move to cube"}, + ] + prompt, spans = _format_messages(msgs, target_indices=[1], eos_token="<eos>") + # EOS is appended to the supervised target (assistant) turn only. + assert prompt == "User: pick cube\nAssistant: move to cube<eos>\n" + # The user span is unchanged; the target span covers content + EOS. 
+ assert prompt[spans[0][0] : spans[0][1]] == "pick cube" + assert prompt[spans[1][0] : spans[1][1]] == "move to cube<eos>" + + +def test_format_messages_without_eos_args_is_unchanged(): + """Inference callers omit target_indices / eos_token — no EOS baked in.""" + prompt, spans = _format_messages([{"role": "user", "content": "hi"}]) + assert prompt == "User: hi\n" + assert prompt[spans[0][0] : spans[0][1]] == "hi" + + +def _eos_char_id() -> int: + """Token id _CharTokenizer assigns to its 1-char EOS.""" + return ord("\x1f") % 251 + 1 + + +def test_pi052_text_tokenizer_supervises_eos_at_target_end(): + """The appended EOS is the last supervised label on a target turn — + that's the signal that teaches the LM head to stop. The trailing + newline right after it stays unsupervised (-100).""" + step = PI052TextTokenizerStep(max_length=64) + step._tokenizer = _CharTokenizer() + transition = { + TransitionKey.OBSERVATION: {}, + TransitionKey.COMPLEMENTARY_DATA: { + "messages": [ + {"role": "user", "content": "pick cube"}, + {"role": "assistant", "content": "move to cube"}, + ], + "target_message_indices": [1], + "message_streams": ["high_level", "high_level"], + "index": torch.tensor(10), + }, + } + out = step(transition) + ids = out[TransitionKey.OBSERVATION][OBS_LANGUAGE_TOKENS][0] + labels = out[TransitionKey.COMPLEMENTARY_DATA]["text_labels"][0] + + supervised = (labels != -100).nonzero().flatten().tolist() + assert supervised, "target turn produced no supervised labels" + last = supervised[-1] + # The last supervised token is the appended EOS. + assert int(ids[last]) == _eos_char_id() + assert int(labels[last]) == _eos_char_id() + # The token right after the EOS (the trailing newline) is NOT supervised. + assert int(labels[last + 1]) == -100 + + class _CharTokenizer: pad_token_id = 0 + eos_token = "\x1f" # unit separator — a 1-char "EOS" for testing def __call__( self,