mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
fix(pi052): supervise an EOS token at the end of each text target
PI052TextTokenizerStep masked text_labels over the assistant turn's *content only* — the trailing newline was excluded and no EOS token was ever a supervised label. So the LM head was never given a stop signal: at inference, select_message kept decoding until max_new_tokens, producing runaway subtask paragraphs and repeated `"}"}"}`-style tails on VQA answers. _format_messages now appends the tokenizer's EOS to each supervised target turn and extends that turn's span to cover it, so the EOS lands in text_labels. _shifted_ce then trains "<last content token> -> EOS" and the model learns to terminate; select_message stops on it. Inference callers (the runtime's _build_text_batch_pi052) pass neither target_indices nor eos_token, so no EOS is baked into the prompt — the model generates it. Verified end-to-end with the PaliGemma tokenizer: the supervised span is `<content><eos>` and the trailing newline stays unsupervised. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -234,21 +234,33 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
|||||||
return [int(value)] * batch_size
|
return [int(value)] * batch_size
|
||||||
|
|
||||||
|
|
||||||
def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]:
|
def _format_messages(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
target_indices: list[int] | None = None,
|
||||||
|
eos_token: str | None = None,
|
||||||
|
) -> tuple[str, list[tuple[int, int]]]:
|
||||||
"""Concatenate messages into the π0.5-style flat prompt.
|
"""Concatenate messages into the π0.5-style flat prompt.
|
||||||
|
|
||||||
|
When both ``target_indices`` and ``eos_token`` are given, the EOS
|
||||||
|
string is appended to each supervised target turn's content and the
|
||||||
|
returned span covers it — so the label builder marks the EOS token
|
||||||
|
as a supervised label. That teaches the LM head where the answer
|
||||||
|
*ends*: without an EOS in the target span the model is never given a
|
||||||
|
stop signal and rambles to ``max_length`` at inference. Inference
|
||||||
|
callers omit both args (no EOS baked into the prompt — the model
|
||||||
|
generates it and ``select_message`` stops on it).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
prompt: the full text the tokenizer will consume.
|
prompt: the full text the tokenizer will consume.
|
||||||
msg_spans: list of ``(char_start, char_end)`` covering each
|
msg_spans: list of ``(char_start, char_end)`` covering each
|
||||||
message's content within ``prompt``. The
|
message's supervised payload (content, plus the
|
||||||
target-mask builder uses this to find the
|
appended EOS for target turns) within ``prompt``.
|
||||||
character ranges belonging to the supervised
|
|
||||||
messages.
|
|
||||||
"""
|
"""
|
||||||
|
targets = set(target_indices or [])
|
||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
spans: list[tuple[int, int]] = []
|
spans: list[tuple[int, int]] = []
|
||||||
cursor = 0
|
cursor = 0
|
||||||
for m in messages:
|
for i, m in enumerate(messages):
|
||||||
role = m.get("role", "user")
|
role = m.get("role", "user")
|
||||||
content = m.get("content", "") or ""
|
content = m.get("content", "") or ""
|
||||||
# Role tag + newline. The model has to learn to emit the same
|
# Role tag + newline. The model has to learn to emit the same
|
||||||
@@ -256,11 +268,15 @@ def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[in
|
|||||||
# decoding because the chat template is implicit in the
|
# decoding because the chat template is implicit in the
|
||||||
# supervised target span.
|
# supervised target span.
|
||||||
header = f"{role.capitalize()}: "
|
header = f"{role.capitalize()}: "
|
||||||
# span covers ONLY the content portion (so labels are computed
|
# A supervised target turn ends with EOS so the model learns to
|
||||||
# over the supervised payload, not the role tag).
|
# terminate; the span below covers content + EOS. Non-target
|
||||||
full = header + content + "\n"
|
# turns (and inference) carry no EOS.
|
||||||
|
body = content + eos_token if (eos_token and i in targets) else content
|
||||||
|
# span covers the content (+ EOS) portion only — never the role
|
||||||
|
# tag — so labels are computed over the supervised payload.
|
||||||
|
full = header + body + "\n"
|
||||||
start = cursor + len(header)
|
start = cursor + len(header)
|
||||||
end = start + len(content)
|
end = start + len(body)
|
||||||
parts.append(full)
|
parts.append(full)
|
||||||
spans.append((start, end))
|
spans.append((start, end))
|
||||||
cursor += len(full)
|
cursor += len(full)
|
||||||
@@ -416,7 +432,11 @@ class PI052TextTokenizerStep(ProcessorStep):
|
|||||||
# stripping, so the spoken reply is actually tokenized and
|
# stripping, so the spoken reply is actually tokenized and
|
||||||
# supervised (PaliGemma's flat prompt has no structured calls).
|
# supervised (PaliGemma's flat prompt has no structured calls).
|
||||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
||||||
prompt, spans = _format_messages(messages)
|
# Append EOS to supervised target turns so the LM head learns to
|
||||||
|
# stop (the span covers it → it becomes a supervised label).
|
||||||
|
prompt, spans = _format_messages(
|
||||||
|
messages, target_indices, getattr(tokenizer, "eos_token", None)
|
||||||
|
)
|
||||||
|
|
||||||
encoded = tokenizer(
|
encoded = tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
|
|||||||
@@ -14,16 +14,22 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""Tests for PI052's text-tokenizer ``say`` tool-call flattening.
|
"""Tests for PI052's text tokenizer.
|
||||||
|
|
||||||
PaliGemma's flat prompt has no structured tool calls, so an assistant
|
Covers ``say`` tool-call flattening (PaliGemma's flat prompt has no
|
||||||
``say`` tool call must be serialized into a ``<say>...</say>`` text
|
structured tool calls, so a ``say`` call must be serialized into a
|
||||||
marker — otherwise the spoken reply is dropped and never supervised.
|
``<say>...</say>`` text marker) and EOS-termination supervision (the
|
||||||
|
supervised target span must end with an EOS token so the LM head learns
|
||||||
|
to stop instead of rambling to ``max_length`` at inference).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.policies.pi052.text_processor_pi052 import PI052TextTokenizerStep, _flatten_say_tool_calls
|
from lerobot.policies.pi052.text_processor_pi052 import (
|
||||||
|
PI052TextTokenizerStep,
|
||||||
|
_flatten_say_tool_calls,
|
||||||
|
_format_messages,
|
||||||
|
)
|
||||||
from lerobot.types import TransitionKey
|
from lerobot.types import TransitionKey
|
||||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||||
|
|
||||||
@@ -64,8 +70,71 @@ def test_flatten_drops_non_say_tool_calls_but_keeps_content():
|
|||||||
assert "tool_calls" not in out
|
assert "tool_calls" not in out
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EOS-termination supervision
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_messages_appends_eos_to_target_turns_only():
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "pick cube"},
|
||||||
|
{"role": "assistant", "content": "move to cube"},
|
||||||
|
]
|
||||||
|
prompt, spans = _format_messages(msgs, target_indices=[1], eos_token="<eos>")
|
||||||
|
# EOS is appended to the supervised target (assistant) turn only.
|
||||||
|
assert prompt == "User: pick cube\nAssistant: move to cube<eos>\n"
|
||||||
|
# The user span is unchanged; the target span covers content + EOS.
|
||||||
|
assert prompt[spans[0][0] : spans[0][1]] == "pick cube"
|
||||||
|
assert prompt[spans[1][0] : spans[1][1]] == "move to cube<eos>"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_messages_without_eos_args_is_unchanged():
|
||||||
|
"""Inference callers omit target_indices / eos_token — no EOS baked in."""
|
||||||
|
prompt, spans = _format_messages([{"role": "user", "content": "hi"}])
|
||||||
|
assert prompt == "User: hi\n"
|
||||||
|
assert prompt[spans[0][0] : spans[0][1]] == "hi"
|
||||||
|
|
||||||
|
|
||||||
|
def _eos_char_id() -> int:
|
||||||
|
"""Token id _CharTokenizer assigns to its 1-char EOS."""
|
||||||
|
return ord("\x1f") % 251 + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_pi052_text_tokenizer_supervises_eos_at_target_end():
|
||||||
|
"""The appended EOS is the last supervised label on a target turn —
|
||||||
|
that's the signal that teaches the LM head to stop. The trailing
|
||||||
|
newline right after it stays unsupervised (-100)."""
|
||||||
|
step = PI052TextTokenizerStep(max_length=64)
|
||||||
|
step._tokenizer = _CharTokenizer()
|
||||||
|
transition = {
|
||||||
|
TransitionKey.OBSERVATION: {},
|
||||||
|
TransitionKey.COMPLEMENTARY_DATA: {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "pick cube"},
|
||||||
|
{"role": "assistant", "content": "move to cube"},
|
||||||
|
],
|
||||||
|
"target_message_indices": [1],
|
||||||
|
"message_streams": ["high_level", "high_level"],
|
||||||
|
"index": torch.tensor(10),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out = step(transition)
|
||||||
|
ids = out[TransitionKey.OBSERVATION][OBS_LANGUAGE_TOKENS][0]
|
||||||
|
labels = out[TransitionKey.COMPLEMENTARY_DATA]["text_labels"][0]
|
||||||
|
|
||||||
|
supervised = (labels != -100).nonzero().flatten().tolist()
|
||||||
|
assert supervised, "target turn produced no supervised labels"
|
||||||
|
last = supervised[-1]
|
||||||
|
# The last supervised token is the appended EOS.
|
||||||
|
assert int(ids[last]) == _eos_char_id()
|
||||||
|
assert int(labels[last]) == _eos_char_id()
|
||||||
|
# The token right after the EOS (the trailing newline) is NOT supervised.
|
||||||
|
assert int(labels[last + 1]) == -100
|
||||||
|
|
||||||
|
|
||||||
class _CharTokenizer:
|
class _CharTokenizer:
|
||||||
pad_token_id = 0
|
pad_token_id = 0
|
||||||
|
eos_token = "\x1f" # unit separator — a 1-char "EOS" for testing
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user