diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py
index 32bb46810..b1197d787 100644
--- a/src/lerobot/policies/pi052/configuration_pi052.py
+++ b/src/lerobot/policies/pi052/configuration_pi052.py
@@ -49,11 +49,9 @@ class PI052Config(PI05Config):
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head
- config. Same recipe-driven training surface; the only differences
- are which backbone the policy uses (PaliGemma here vs SmolVLM2
- there) and the default loss-weight scale (paper §IV.D uses
- ``α=10`` between the two heads, which we encode as
- ``flow_loss_weight=10, text_loss_weight=1``).
+ config. Same recipe-driven training surface; the only difference is
+ which backbone the policy uses (PaliGemma here vs SmolVLM2 there).
+ The flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
"""
# Recipe / language stack ---------------------------------------------
@@ -72,16 +70,20 @@ class PI052Config(PI05Config):
samples text auto-regressively after the prefix."""
# Loss weights --------------------------------------------------------
- # Paper §IV.D: total = H(text) + α * MSE(flow), α = 10. We split
- # the same total into two configurable knobs so individual scaling
- # is recoverable.
+ # Paper §IV.D uses α=10 between the flow and text terms, assuming
+ # text is a rare auxiliary task. With the recipe stack the flow-only
+ # `low_level` branch fires on a large share of samples, so α=10
+ # swamps the LM head and collapses generation into degenerate
+ # repetition. We use the milder 5:1 split (matches SmolVLA2Config).
text_loss_weight: float = 1.0
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
text training entirely (reverts to flow-only / π0.5 behaviour)."""
- flow_loss_weight: float = 10.0
- """Weight on the action-expert flow-matching term. Default ``10.0``
- matches the paper's α."""
+ flow_loss_weight: float = 5.0
+ """Weight on the action-expert flow-matching term. ``5.0`` — a milder
+ flow:text split than the paper's α=10, since the flow-only
+ ``low_level`` recipe already gives the action expert frequent
+ gradient. Lower it further if the LM head still underfits."""
# Backbone training ---------------------------------------------------
unfreeze_lm_head: bool = True
diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py
index 34b07168a..9069b05ca 100644
--- a/src/lerobot/policies/pi052/modeling_pi052.py
+++ b/src/lerobot/policies/pi052/modeling_pi052.py
@@ -86,6 +86,41 @@ def _shifted_ce(logits: Tensor, labels: Tensor) -> Tensor:
return loss / valid.sum().clamp(min=1)
+def _mark_target_span_causal(
+ prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int
+) -> Tensor:
+ """Make the supervised text-target span causally masked.
+
+ ``embed_prefix`` lays the PaliGemma prefix out as ``[images,
+ language]`` with the language block flagged ``att=0`` — which
+ ``make_att_2d_masks`` turns into one fully *bidirectional* block.
+ A supervised target token's hidden state then attends to the very
+ tokens it is trained to predict, so the text cross-entropy
+ degenerates into a copy task (loss → ~0) and the LM head never
+ learns causal next-token prediction. At inference ``select_message``
+ decodes autoregressively (causally) and the head collapses to
+ repeated/garbage tokens.
+
+ Fix: set ``att=1`` on the language positions that are supervised
+ targets (``text_labels != -100``). Under ``make_att_2d_masks``'s
+ cumulative-block rule each target token then attends bidirectionally
+ to images + the user prompt and causally to *earlier* targets only —
+ genuine next-token prediction, matching inference. Non-target
+ language (the user prompt, the flow-only ``low_level`` subtask) stays
+ ``att=0`` / bidirectional. The action expert / FAST tokens are
+ unaffected: they sit at a strictly higher cumsum and still attend to
+ every prefix token.
+ """
+ att = prefix_att_masks.clone()
+ n = min(text_labels.shape[1], lang_end - lang_start)
+ if n <= 0:
+ return att
+ target = text_labels[:, :n] != -100 # (B, n) bool
+ seg = att[:, lang_start : lang_start + n].bool()
+ att[:, lang_start : lang_start + n] = seg | target
+ return att
+
+
def _fast_ce(
fast_logits: Tensor,
action_tokens: Tensor,
@@ -519,6 +554,16 @@ class PI052Policy(PI05Policy):
)
non_fast_prefix_len = prefix_embs.shape[1] # images + language only
+ # Causal-mask the supervised text-target span so the text-CE is
+ # genuine next-token prediction, not a bidirectional copy task
+ # (see ``_mark_target_span_causal``).
+ if text_labels is not None:
+ lang_start = non_fast_prefix_len - text_labels.shape[1]
+ if lang_start >= 0:
+ prefix_att = _mark_target_span_causal(
+ prefix_att, text_labels, lang_start, non_fast_prefix_len
+ )
+
fast_len = 0
if action_tokens is not None and action_mask is not None:
emb_dim = prefix_embs.shape[-1]
@@ -640,6 +685,16 @@ class PI052Policy(PI05Policy):
images, img_masks, lang_tokens, lang_masks
)
+ # Causal-mask the supervised text-target span (see
+ # ``_mark_target_span_causal``) before the FAST tokens are
+ # appended — same fix as ``_compute_all_losses_fused``.
+ if text_labels is not None:
+ lang_start = prefix_embs.shape[1] - text_labels.shape[1]
+ if lang_start >= 0:
+ prefix_att = _mark_target_span_causal(
+ prefix_att, text_labels, lang_start, prefix_embs.shape[1]
+ )
+
fast_len = 0
if action_tokens is not None and action_mask is not None:
emb_dim = prefix_embs.shape[-1]
diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py
index c957cf590..38a4d082e 100644
--- a/src/lerobot/policies/pi052/text_processor_pi052.py
+++ b/src/lerobot/policies/pi052/text_processor_pi052.py
@@ -125,6 +125,62 @@ def _dump_recipe_sample(
print("==============================================\n", flush=True)
+def _content_to_text(content: Any) -> str:
+ """Collapse a message's ``content`` (string or multimodal blocks) to text."""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts = [
+ b["text"]
+ for b in content
+ if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
+ ]
+ return "\n".join(parts)
+ return ""
+
+
+def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
+ """Serialize assistant ``say`` tool calls into a ``...`` marker.
+
+ PaliGemma's flat text prompt has no notion of structured tool calls,
+ and ``_format_messages`` only reads ``role`` / ``content`` — so
+ without this a ``say`` tool call is dropped entirely and never
+ supervised. Rewriting it into the content text as a ``...``
+ marker lets the LM head learn to emit it; the runtime parses it back
+ via ``_split_plan_and_say``. Messages without ``say`` tool calls are
+ returned unchanged (the structured calls, if any, are still dropped).
+ """
+ tool_calls = message.get("tool_calls")
+ if not tool_calls:
+ return message
+ say_texts: list[str] = []
+ for call in tool_calls:
+ if not isinstance(call, dict):
+ continue
+ fn = call.get("function") or {}
+ if fn.get("name") != "say":
+ continue
+ args = fn.get("arguments")
+ if isinstance(args, str):
+ try:
+ import json # noqa: PLC0415
+
+ args = json.loads(args)
+ except (ValueError, TypeError):
+ args = {}
+ text = args.get("text", "") if isinstance(args, dict) else ""
+ if text:
+ say_texts.append(str(text))
+ new = dict(message)
+ new.pop("tool_calls", None)
+ if not say_texts:
+ return new
+ base = _content_to_text(new.get("content")).strip()
+ marker = "".join(f"{t}" for t in say_texts)
+ new["content"] = f"{base}\n{marker}" if base else marker
+ return new
+
+
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
"""Normalise a message's content to a plain string.
@@ -253,7 +309,10 @@ class PI052TextTokenizerStep(ProcessorStep):
complementary,
)
- messages = [_strip_blocks(m) for m in messages]
+ # Flatten ``say`` tool calls into ``...`` text before
+ # stripping, so the spoken reply is actually tokenized and
+ # supervised (PaliGemma's flat prompt has no structured calls).
+ messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
prompt, spans = _format_messages(messages)
tokenizer = self._ensure_tokenizer()
diff --git a/tests/policies/pi052/test_pi052_attention_masking.py b/tests/policies/pi052/test_pi052_attention_masking.py
new file mode 100644
index 000000000..96ff4b479
--- /dev/null
+++ b/tests/policies/pi052/test_pi052_attention_masking.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python
+
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Attention-masking tests for the PI052 (π0.5 v2) text head.
+
+Regression coverage for the text-CE collapse bug: PaliGemma's
+``embed_prefix`` flags every language token ``att=0``, which
+``make_att_2d_masks`` turns into one fully *bidirectional* block. Under
+that mask the text cross-entropy degenerates into a copy task — a
+supervised target token attends to the tokens it is trained to predict —
+and the LM head never learns causal generation, so ``select_message``
+collapses at inference.
+
+``_mark_target_span_causal`` sets ``att=1`` on the supervised target
+language positions so each target token attends causally among the
+targets while staying bidirectional to images + the user prompt. These
+tests pin that behaviour for the PaliGemma prefix layout.
+"""
+
+import pytest
+import torch
+
+# modeling_pi052 / modeling_pi05 import transformers transitively.
+pytest.importorskip("transformers")
+
+from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
+from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal # noqa: E402
+
+# ---------------------------------------------------------------------------
+# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
+#
+# indices 0-1 : 2 image tokens (att = 0)
+# indices 2-4 : 3 user-prompt lang (att = 0)
+# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix)
+#
+# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
+# real ids on the 4-token target span. PaliGemma's prefix has no state
+# token (unlike SmolVLA), so the lang span ends at the prefix end.
+# ---------------------------------------------------------------------------
+N_IMAGE = 2
+N_PROMPT = 3
+N_TARGET = 4
+LANG_START = N_IMAGE
+LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = prefix length
+PREFIX_LEN = LANG_END
+
+
+def _embed_prefix_att_masks() -> torch.Tensor:
+ """Mimic PaliGemma ``embed_prefix``: images + lang all att=0."""
+ return torch.zeros(1, PREFIX_LEN, dtype=torch.bool)
+
+
+def _text_labels() -> torch.Tensor:
+ """-100 over the prompt span, real ids over the target span."""
+ labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
+ labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET)
+ return labels
+
+
+def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
+ """2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j."""
+ pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool)
+ return make_att_2d_masks(pad, prefix_att_masks)[0]
+
+
+def test_mark_sets_att_on_targets_only():
+ """Only the supervised target language positions flip to att=1."""
+ marked = _mark_target_span_causal(
+ _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
+ )
+ expected = [False] * PREFIX_LEN
+ for i in range(LANG_START + N_PROMPT, LANG_END): # target span
+ expected[i] = True
+ assert marked[0].tolist() == expected
+
+
+def test_target_tokens_attend_causally_among_themselves():
+ """A target token must NOT attend to later targets, but must attend
+ to earlier ones — genuine causal next-token prediction."""
+ marked = _mark_target_span_causal(
+ _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
+ )
+ attends = _attends(marked)
+ tgt = range(LANG_START + N_PROMPT, LANG_END)
+ for i in tgt:
+ for j in tgt:
+ if j > i:
+ assert not attends[i, j], f"target {i} must not see future target {j}"
+ else:
+ assert attends[i, j], f"target {i} must see earlier/self target {j}"
+
+
+def test_target_tokens_attend_prompt_and_images_bidirectionally():
+ """Targets keep full visibility of images + the user prompt."""
+ marked = _mark_target_span_causal(
+ _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
+ )
+ attends = _attends(marked)
+ context = list(range(0, LANG_START + N_PROMPT)) # images + prompt
+ for i in range(LANG_START + N_PROMPT, LANG_END):
+ for j in context:
+ assert attends[i, j], f"target {i} must attend context {j}"
+
+
+def test_non_target_subtask_stays_bidirectional():
+ """A flow-only / non-target language span (all -100 labels) leaves the
+ mask untouched — the action expert reads it bidirectionally."""
+ all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
+ marked = _mark_target_span_causal(
+ _embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END
+ )
+ assert torch.equal(marked, _embed_prefix_att_masks())
+
+
+def test_unmarked_mask_is_bidirectional_the_bug():
+ """Documents the bug the fix prevents: without ``_mark_target_span_causal``
+ a target token attends *bidirectionally* to later targets — the
+ text-CE can copy the answer it is trained to predict."""
+ attends = _attends(_embed_prefix_att_masks())
+ first_tgt = LANG_START + N_PROMPT
+ last_tgt = LANG_END - 1
+ assert attends[first_tgt, last_tgt], (
+ "raw embed_prefix mask is bidirectional over language — the first "
+ "target token can see the last, which is the collapse bug"
+ )
diff --git a/tests/policies/pi052/test_pi052_text_processor.py b/tests/policies/pi052/test_pi052_text_processor.py
new file mode 100644
index 000000000..918582845
--- /dev/null
+++ b/tests/policies/pi052/test_pi052_text_processor.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for PI052's text-tokenizer ``say`` tool-call flattening.
+
+PaliGemma's flat prompt has no structured tool calls, so an assistant
+``say`` tool call must be serialized into a ``...`` text
+marker — otherwise the spoken reply is dropped and never supervised.
+"""
+
+from lerobot.policies.pi052.text_processor_pi052 import _flatten_say_tool_calls
+
+
+def _say_call(text):
+ return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
+
+
+def test_flatten_appends_say_marker_and_drops_tool_calls():
+ msg = {"role": "assistant", "content": "Heading to the cube.", "tool_calls": [_say_call("On it!")]}
+ out = _flatten_say_tool_calls(msg)
+ assert "tool_calls" not in out
+ assert out["content"] == "Heading to the cube.\nOn it!"
+
+
+def test_flatten_marker_only_when_content_empty_or_none():
+ out = _flatten_say_tool_calls({"role": "assistant", "tool_calls": [_say_call("hi")]})
+ assert out["content"] == "hi"
+
+
+def test_flatten_accepts_json_string_arguments():
+ call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
+ out = _flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
+ assert out["content"] == "p\nhello there"
+
+
+def test_flatten_leaves_messages_without_tool_calls_untouched():
+ msg = {"role": "assistant", "content": "just a plan"}
+ assert _flatten_say_tool_calls(msg) == msg
+
+
+def test_flatten_drops_non_say_tool_calls_but_keeps_content():
+ weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
+ out = _flatten_say_tool_calls(
+ {"role": "assistant", "content": "plan only", "tool_calls": [weather]}
+ )
+ assert out["content"] == "plan only"
+ assert "tool_calls" not in out