diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index 32bb46810..b1197d787 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -49,11 +49,9 @@ class PI052Config(PI05Config): """π0.5 with the PaliGemma LM head re-enabled for subtask prediction. See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head - config. Same recipe-driven training surface; the only differences - are which backbone the policy uses (PaliGemma here vs SmolVLM2 - there) and the default loss-weight scale (paper §IV.D uses - ``α=10`` between the two heads, which we encode as - ``flow_loss_weight=10, text_loss_weight=1``). + config. Same recipe-driven training surface; the only difference is + which backbone the policy uses (PaliGemma here vs SmolVLM2 there). + The flow:text loss split is the milder 5:1 (see ``flow_loss_weight``). """ # Recipe / language stack --------------------------------------------- @@ -72,16 +70,20 @@ class PI052Config(PI05Config): samples text auto-regressively after the prefix.""" # Loss weights -------------------------------------------------------- - # Paper §IV.D: total = H(text) + α * MSE(flow), α = 10. We split - # the same total into two configurable knobs so individual scaling - # is recoverable. + # Paper §IV.D uses α=10 between the flow and text terms, assuming + # text is a rare auxiliary task. With the recipe stack the flow-only + # `low_level` branch fires on a large share of samples, so α=10 + # swamps the LM head and collapses generation into degenerate + # repetition. We use the milder 5:1 split (matches SmolVLA2Config). text_loss_weight: float = 1.0 """Weight on the LM-head cross-entropy term. Set to ``0`` to disable text training entirely (reverts to flow-only / π0.5 behaviour).""" - flow_loss_weight: float = 10.0 - """Weight on the action-expert flow-matching term. Default ``10.0`` - matches the paper's α.""" + flow_loss_weight: float = 5.0 + """Weight on the action-expert flow-matching term. ``5.0`` — a milder + flow:text split than the paper's α=10, since the flow-only + ``low_level`` recipe already gives the action expert frequent + gradient. Lower it further if the LM head still underfits.""" # Backbone training --------------------------------------------------- unfreeze_lm_head: bool = True diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 34b07168a..9069b05ca 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -86,6 +86,41 @@ def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor: return loss / valid.sum().clamp(min=1) +def _mark_target_span_causal( + prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int +) -> Tensor: + """Make the supervised text-target span causally masked. + + ``embed_prefix`` lays the PaliGemma prefix out as ``[images, + language]`` with the language block flagged ``att=0`` — which + ``make_att_2d_masks`` turns into one fully *bidirectional* block. + A supervised target token's hidden state then attends to the very + tokens it is trained to predict, so the text cross-entropy + degenerates into a copy task (loss → ~0) and the LM head never + learns causal next-token prediction. At inference ``select_message`` + decodes autoregressively (causally) and the head collapses to + repeated/garbage tokens. + + Fix: set ``att=1`` on the language positions that are supervised + targets (``text_labels != -100``). Under ``make_att_2d_masks``'s + cumulative-block rule each target token then attends bidirectionally + to images + the user prompt and causally to *earlier* targets only — + genuine next-token prediction, matching inference. Non-target + language (the user prompt, the flow-only ``low_level`` subtask) stays + ``att=0`` / bidirectional. The action expert / FAST tokens are + unaffected: they sit at a strictly higher cumsum and still attend to + every prefix token. + """ + att = prefix_att_masks.clone() + n = min(text_labels.shape[1], lang_end - lang_start) + if n <= 0: + return att + target = text_labels[:, :n] != -100 # (B, n) bool + seg = att[:, lang_start : lang_start + n].bool() + att[:, lang_start : lang_start + n] = seg | target + return att + + def _fast_ce( fast_logits: Tensor, action_tokens: Tensor, @@ -519,6 +554,16 @@ class PI052Policy(PI05Policy): ) non_fast_prefix_len = prefix_embs.shape[1] # images + language only + # Causal-mask the supervised text-target span so the text-CE is + # genuine next-token prediction, not a bidirectional copy task + # (see ``_mark_target_span_causal``). + if text_labels is not None: + lang_start = non_fast_prefix_len - text_labels.shape[1] + if lang_start >= 0: + prefix_att = _mark_target_span_causal( + prefix_att, text_labels, lang_start, non_fast_prefix_len + ) + fast_len = 0 if action_tokens is not None and action_mask is not None: emb_dim = prefix_embs.shape[-1] @@ -640,6 +685,16 @@ class PI052Policy(PI05Policy): images, img_masks, lang_tokens, lang_masks ) + # Causal-mask the supervised text-target span (see + # ``_mark_target_span_causal``) before the FAST tokens are + # appended — same fix as ``_compute_all_losses_fused``. + if text_labels is not None: + lang_start = prefix_embs.shape[1] - text_labels.shape[1] + if lang_start >= 0: + prefix_att = _mark_target_span_causal( + prefix_att, text_labels, lang_start, prefix_embs.shape[1] + ) + fast_len = 0 if action_tokens is not None and action_mask is not None: emb_dim = prefix_embs.shape[-1] diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py index c957cf590..38a4d082e 100644 --- a/src/lerobot/policies/pi052/text_processor_pi052.py +++ b/src/lerobot/policies/pi052/text_processor_pi052.py @@ -125,6 +125,62 @@ def _dump_recipe_sample( print("==============================================\n", flush=True) +def _content_to_text(content: Any) -> str: + """Collapse a message's ``content`` (string or multimodal blocks) to text.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [ + b["text"] + for b in content + if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str) + ] + return "\n".join(parts) + return "" + + +def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]: + """Serialize assistant ``say`` tool calls into a ``...`` marker. + + PaliGemma's flat text prompt has no notion of structured tool calls, + and ``_format_messages`` only reads ``role`` / ``content`` — so + without this a ``say`` tool call is dropped entirely and never + supervised. Rewriting it into the content text as a ``...`` + marker lets the LM head learn to emit it; the runtime parses it back + via ``_split_plan_and_say``. Messages without ``say`` tool calls are + returned unchanged (the structured calls, if any, are still dropped). + """ + tool_calls = message.get("tool_calls") + if not tool_calls: + return message + say_texts: list[str] = [] + for call in tool_calls: + if not isinstance(call, dict): + continue + fn = call.get("function") or {} + if fn.get("name") != "say": + continue + args = fn.get("arguments") + if isinstance(args, str): + try: + import json # noqa: PLC0415 + + args = json.loads(args) + except (ValueError, TypeError): + args = {} + text = args.get("text", "") if isinstance(args, dict) else "" + if text: + say_texts.append(str(text)) + new = dict(message) + new.pop("tool_calls", None) + if not say_texts: + return new + base = _content_to_text(new.get("content")).strip() + marker = "".join(f"{t}" for t in say_texts) + new["content"] = f"{base}\n{marker}" if base else marker + return new + + def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]: """Normalise a message's content to a plain string. @@ -253,7 +309,10 @@ class PI052TextTokenizerStep(ProcessorStep): complementary, ) - messages = [_strip_blocks(m) for m in messages] + # Flatten ``say`` tool calls into ``...`` text before + # stripping, so the spoken reply is actually tokenized and + # supervised (PaliGemma's flat prompt has no structured calls). + messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages] prompt, spans = _format_messages(messages) tokenizer = self._ensure_tokenizer() diff --git a/tests/policies/pi052/test_pi052_attention_masking.py b/tests/policies/pi052/test_pi052_attention_masking.py new file mode 100644 index 000000000..96ff4b479 --- /dev/null +++ b/tests/policies/pi052/test_pi052_attention_masking.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attention-masking tests for the PI052 (π0.5 v2) text head. + +Regression coverage for the text-CE collapse bug: PaliGemma's +``embed_prefix`` flags every language token ``att=0``, which +``make_att_2d_masks`` turns into one fully *bidirectional* block. Under +that mask the text cross-entropy degenerates into a copy task — a +supervised target token attends to the tokens it is trained to predict — +and the LM head never learns causal generation, so ``select_message`` +collapses at inference. + +``_mark_target_span_causal`` sets ``att=1`` on the supervised target +language positions so each target token attends causally among the +targets while staying bidirectional to images + the user prompt. These +tests pin that behaviour for the PaliGemma prefix layout. +""" + +import pytest +import torch + +# modeling_pi052 / modeling_pi05 import transformers transitively. +pytest.importorskip("transformers") + +from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402 +from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal # noqa: E402 + +# --------------------------------------------------------------------------- +# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang] +# +# indices 0-1 : 2 image tokens (att = 0) +# indices 2-4 : 3 user-prompt lang (att = 0) +# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix) +# +# ``text_labels`` covers the 7 language tokens; -100 on the prompt span, +# real ids on the 4-token target span. PaliGemma's prefix has no state +# token (unlike SmolVLA), so the lang span ends at the prefix end. +# --------------------------------------------------------------------------- +N_IMAGE = 2 +N_PROMPT = 3 +N_TARGET = 4 +LANG_START = N_IMAGE +LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = prefix length +PREFIX_LEN = LANG_END + + +def _embed_prefix_att_masks() -> torch.Tensor: + """Mimic PaliGemma ``embed_prefix``: images + lang all att=0.""" + return torch.zeros(1, PREFIX_LEN, dtype=torch.bool) + + +def _text_labels() -> torch.Tensor: + """-100 over the prompt span, real ids over the target span.""" + labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long) + labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET) + return labels + + +def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor: + """2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j.""" + pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool) + return make_att_2d_masks(pad, prefix_att_masks)[0] + + +def test_mark_sets_att_on_targets_only(): + """Only the supervised target language positions flip to att=1.""" + marked = _mark_target_span_causal( + _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END + ) + expected = [False] * PREFIX_LEN + for i in range(LANG_START + N_PROMPT, LANG_END): # target span + expected[i] = True + assert marked[0].tolist() == expected + + +def test_target_tokens_attend_causally_among_themselves(): + """A target token must NOT attend to later targets, but must attend + to earlier ones — genuine causal next-token prediction.""" + marked = _mark_target_span_causal( + _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END + ) + attends = _attends(marked) + tgt = range(LANG_START + N_PROMPT, LANG_END) + for i in tgt: + for j in tgt: + if j > i: + assert not attends[i, j], f"target {i} must not see future target {j}" + else: + assert attends[i, j], f"target {i} must see earlier/self target {j}" + + +def test_target_tokens_attend_prompt_and_images_bidirectionally(): + """Targets keep full visibility of images + the user prompt.""" + marked = _mark_target_span_causal( + _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END + ) + attends = _attends(marked) + context = list(range(0, LANG_START + N_PROMPT)) # images + prompt + for i in range(LANG_START + N_PROMPT, LANG_END): + for j in context: + assert attends[i, j], f"target {i} must attend context {j}" + + +def test_non_target_subtask_stays_bidirectional(): + """A flow-only / non-target language span (all -100 labels) leaves the + mask untouched — the action expert reads it bidirectionally.""" + all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long) + marked = _mark_target_span_causal( + _embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END + ) + assert torch.equal(marked, _embed_prefix_att_masks()) + + +def test_unmarked_mask_is_bidirectional_the_bug(): + """Documents the bug the fix prevents: without ``_mark_target_span_causal`` + a target token attends *bidirectionally* to later targets — the + text-CE can copy the answer it is trained to predict.""" + attends = _attends(_embed_prefix_att_masks()) + first_tgt = LANG_START + N_PROMPT + last_tgt = LANG_END - 1 + assert attends[first_tgt, last_tgt], ( + "raw embed_prefix mask is bidirectional over language — the first " + "target token can see the last, which is the collapse bug" + ) diff --git a/tests/policies/pi052/test_pi052_text_processor.py b/tests/policies/pi052/test_pi052_text_processor.py new file mode 100644 index 000000000..918582845 --- /dev/null +++ b/tests/policies/pi052/test_pi052_text_processor.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for PI052's text-tokenizer ``say`` tool-call flattening. + +PaliGemma's flat prompt has no structured tool calls, so an assistant +``say`` tool call must be serialized into a ``...`` text +marker — otherwise the spoken reply is dropped and never supervised. +""" + +from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls + + +def _say_call(text): + return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}} + + +def test_flatten_appends_say_marker_and_drops_tool_calls(): + msg = {"role": "assistant", "content": "Heading to the cube.", "tool_calls": [_say_call("On it!")]} + out = _flatten_say_tool_calls(msg) + assert "tool_calls" not in out + assert out["content"] == "Heading to the cube.\nOn it!" + + +def test_flatten_marker_only_when_content_empty_or_none(): + out = _flatten_say_tool_calls({"role": "assistant", "tool_calls": [_say_call("hi")]}) + assert out["content"] == "hi" + + +def test_flatten_accepts_json_string_arguments(): + call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}} + out = _flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]}) + assert out["content"] == "p\nhello there" + + +def test_flatten_leaves_messages_without_tool_calls_untouched(): + msg = {"role": "assistant", "content": "just a plan"} + assert _flatten_say_tool_calls(msg) == msg + + +def test_flatten_drops_non_say_tool_calls_but_keeps_content(): + weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}} + out = _flatten_say_tool_calls( + {"role": "assistant", "content": "plan only", "tool_calls": [weather]} + ) + assert out["content"] == "plan only" + assert "tool_calls" not in out