diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py
index 559e51849..6ebd54168 100644
--- a/src/lerobot/policies/pi052/text_processor_pi052.py
+++ b/src/lerobot/policies/pi052/text_processor_pi052.py
@@ -234,21 +234,33 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
return [int(value)] * batch_size
-def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]:
+def _format_messages(
+ messages: list[dict[str, Any]],
+ target_indices: list[int] | None = None,
+ eos_token: str | None = None,
+) -> tuple[str, list[tuple[int, int]]]:
"""Concatenate messages into the π0.5-style flat prompt.
+ When both ``target_indices`` and ``eos_token`` are given, the EOS
+ string is appended to each supervised target turn's content and the
+ returned span covers it — so the label builder marks the EOS token
+ as a supervised label. That teaches the LM head where the answer
+ *ends*: without an EOS in the target span the model is never given a
+ stop signal and rambles to ``max_length`` at inference. Inference
+ callers omit both args (no EOS baked into the prompt — the model
+ generates it and ``select_message`` stops on it).
+
Returns:
prompt: the full text the tokenizer will consume.
msg_spans: list of ``(char_start, char_end)`` covering each
- message's content within ``prompt``. The
- target-mask builder uses this to find the
- character ranges belonging to the supervised
- messages.
+ message's supervised payload (content, plus the
+ appended EOS for target turns) within ``prompt``.
"""
+ targets = set(target_indices or [])
parts: list[str] = []
spans: list[tuple[int, int]] = []
cursor = 0
- for m in messages:
+ for i, m in enumerate(messages):
role = m.get("role", "user")
content = m.get("content", "") or ""
# Role tag + newline. The model has to learn to emit the same
@@ -256,11 +268,15 @@ def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[in
# decoding because the chat template is implicit in the
# supervised target span.
header = f"{role.capitalize()}: "
- # span covers ONLY the content portion (so labels are computed
- # over the supervised payload, not the role tag).
- full = header + content + "\n"
+ # A supervised target turn ends with EOS so the model learns to
+ # terminate; the span below covers content + EOS. Non-target
+ # turns (and inference) carry no EOS.
+ body = content + eos_token if (eos_token and i in targets) else content
+ # span covers the content (+ EOS) portion only — never the role
+ # tag — so labels are computed over the supervised payload.
+ full = header + body + "\n"
start = cursor + len(header)
- end = start + len(content)
+ end = start + len(body)
parts.append(full)
spans.append((start, end))
cursor += len(full)
@@ -416,7 +432,11 @@ class PI052TextTokenizerStep(ProcessorStep):
# stripping, so the spoken reply is actually tokenized and
# supervised (PaliGemma's flat prompt has no structured calls).
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
- prompt, spans = _format_messages(messages)
+ # Append EOS to supervised target turns so the LM head learns to
+ # stop (the span covers it → it becomes a supervised label).
+ prompt, spans = _format_messages(
+ messages, target_indices, getattr(tokenizer, "eos_token", None)
+ )
encoded = tokenizer(
prompt,
diff --git a/tests/policies/pi052/test_pi052_text_processor.py b/tests/policies/pi052/test_pi052_text_processor.py
index 9547c2a20..77695e12e 100644
--- a/tests/policies/pi052/test_pi052_text_processor.py
+++ b/tests/policies/pi052/test_pi052_text_processor.py
@@ -14,16 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Tests for PI052's text-tokenizer ``say`` tool-call flattening.
+"""Tests for PI052's text tokenizer.
-PaliGemma's flat prompt has no structured tool calls, so an assistant
-``say`` tool call must be serialized into a ``<say>...</say>`` text
-marker — otherwise the spoken reply is dropped and never supervised.
+Covers ``say`` tool-call flattening (PaliGemma's flat prompt has no
+structured tool calls, so a ``say`` call must be serialized into a
+``<say>...</say>`` text marker) and EOS-termination supervision (the
+supervised target span must end with an EOS token so the LM head learns
+to stop instead of rambling to ``max_length`` at inference).
"""
import torch
-from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls
+from lerobot.policies.pi052.text_processor_pi052 import (
+ PI052TextTokenizerStep,
+ _flatten_say_tool_calls,
+ _format_messages,
+)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
@@ -64,8 +70,71 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content():
assert "tool_calls" not in out
+# ---------------------------------------------------------------------------
+# EOS-termination supervision
+# ---------------------------------------------------------------------------
+
+
+def test_format_messages_appends_eos_to_target_turns_only():
+    """EOS lands on the supervised target turn only, and its span covers it."""
+    msgs = [
+        {"role": "user", "content": "pick cube"},
+        {"role": "assistant", "content": "move to cube"},
+    ]
+    # A non-empty EOS is required: an empty/None one is falsy and appends nothing.
+    prompt, spans = _format_messages(msgs, target_indices=[1], eos_token="<eos>")
+    assert prompt == "User: pick cube\nAssistant: move to cube<eos>\n"
+    assert prompt[spans[0][0] : spans[0][1]] == "pick cube"  # user span: content only
+    assert prompt[spans[1][0] : spans[1][1]] == "move to cube<eos>"  # target span: content + EOS
+
+
+def test_format_messages_without_eos_args_is_unchanged():
+    """With target_indices / eos_token omitted (the inference path), no EOS is added."""
+    text, msg_spans = _format_messages([{"role": "user", "content": "hi"}])
+    first_start, first_end = msg_spans[0]
+    assert (text, text[first_start:first_end]) == ("User: hi\n", "hi")
+
+
+def _eos_char_id() -> int:
+    """Token id ``_CharTokenizer`` assigns to its 1-char ``eos_token``."""
+    return ord(_CharTokenizer.eos_token) % 251 + 1
+
+
+def test_pi052_text_tokenizer_supervises_eos_at_target_end():
+    """A target turn's final supervised label must be the appended EOS.
+
+    The newline that follows it in the prompt stays masked (-100).
+    """
+    tokenizer_step = PI052TextTokenizerStep(max_length=64)
+    tokenizer_step._tokenizer = _CharTokenizer()
+    conversation = [
+        {"role": "user", "content": "pick cube"},
+        {"role": "assistant", "content": "move to cube"},
+    ]
+    complementary = {
+        "messages": conversation,
+        "target_message_indices": [1],
+        "message_streams": ["high_level", "high_level"],
+        "index": torch.tensor(10),
+    }
+    result = tokenizer_step(
+        {TransitionKey.OBSERVATION: {}, TransitionKey.COMPLEMENTARY_DATA: complementary}
+    )
+    token_ids = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_TOKENS][0]
+    text_labels = result[TransitionKey.COMPLEMENTARY_DATA]["text_labels"][0]
+
+    supervised_positions = (text_labels != -100).nonzero().flatten().tolist()
+    assert supervised_positions, "target turn produced no supervised labels"
+    eos_pos, eos_id = supervised_positions[-1], _eos_char_id()
+    assert int(token_ids[eos_pos]) == eos_id
+    assert int(text_labels[eos_pos]) == eos_id
+    # The position right after the EOS (the trailing newline) is not supervised.
+    assert int(text_labels[eos_pos + 1]) == -100
+
+
class _CharTokenizer:
pad_token_id = 0
+ eos_token = "\x1f" # unit separator — a 1-char "EOS" for testing
def __call__(
self,