feat(pi052): π0.5 v2 — full reproduction of the π0.5 paper recipe

New ``lerobot.policies.pi052`` (parallel to ``smolvla2``) that adds
text-prediction + hierarchical-inference on top of the existing π0.5
implementation. Mirrors the paper's §IV.D dual-head training:

  L = H(text) + α * ‖ω - a - f_θ_action(...)‖²,  α = 10

Components:

  * ``configuration_pi052.py``     thin PI05Config subclass; adds
                                    recipe_path, text/flow loss weights
                                    (default α=10 per paper), prompt
                                    dropout knobs, ``unfreeze_lm_head``.
  * ``text_processor_pi052.py``    PI052TextTokenizerStep — concatenates
                                    rendered messages as ``Role: ...``
                                    plain text (PaliGemma has no chat
                                    template), tokenises with the
                                    PaliGemma tokenizer, builds a label
                                    mask covering supervised target
                                    spans. Includes Pi 0.7 §V.E
                                    per-component prompt dropout.
  * ``processor_pi052.py``         make_pi052_pre_post_processors —
                                    Rename + Batch + Relative +
                                    Normalize + RenderMessagesStep +
                                    PI052TextTokenizerStep + Device.
                                    Falls back to π0.5's plain pipeline
                                    when recipe_path is unset.
  * ``modeling_pi052.py``          PI052Policy(PI05Policy) — re-enables
                                    PaliGemma ``lm_head``, computes
                                    text_loss via CE on the supervised
                                    span, sums with flow_loss in
                                    forward(), and adds select_message
                                    for AR text generation at inference
                                    (same surface as
                                    SmolVLA2Policy.select_message so
                                    SmolVLA2Runtime drives it unchanged).

Plus the supporting plumbing:

  * recipe ``configs/recipes/pi052_hirobot.yaml`` — same Hi-Robot blend
    as smolvla2_hirobot.yaml, with the same ``${subtask}`` /
    ``if_present`` supervision fix (current span at every frame, not
    ``${next_subtask}``).
  * SLURM ``examples/training/pi052_hirobot.slurm`` — full training
    command matching the SmolVLA2 launcher.
  * factory registration: ``--policy.type=pi052`` resolves to
    PI052Policy with the new processor.

Same multi-rate runtime (``lerobot.policies.smolvla2.inference``)
drives this policy too — both expose ``predict_action_chunk`` for the
action expert and ``select_message`` for the LM head.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-13 10:59:26 +02:00
parent 1d24301b67
commit ef5879a02a
8 changed files with 1127 additions and 0 deletions
+75
View File
@@ -0,0 +1,75 @@
#!/bin/bash
#SBATCH --job-name=pi052-hirobot
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=48:00:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=8
# π0.5 v2 training — reproduces the π0.5 paper's hierarchical recipe.
#
# Same recipe blend as the SmolVLA2 stack (recipes/pi052_hirobot.yaml),
# just on the PaliGemma 2B + Gemma-300m action-expert backbone the
# paper uses. The text head learns subtask prediction via cross-
# entropy on supervised spans; the action expert learns the flow
# field. Paper §IV.D mixes the two losses with α=10, which we encode
# as flow_loss_weight=10 / text_loss_weight=1.
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/pi052_hirobot_super_poulain}"
JOB_NAME="${JOB_NAME:-pi052-hirobot-super-poulain}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
BATCH_SIZE="${BATCH_SIZE:-32}"
STEPS="${STEPS:-15000}"
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/pi052_hirobot_${STEPS}_${RUN_ID}}"
echo "Training pi052 on $DATASET"
echo " GPUs: $NUM_PROCESSES"
echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
echo " steps: $STEPS"
echo " output: $OUTPUT_DIR"
echo " loss mix: flow_loss_weight=10 (paper α), text_loss_weight=1"
echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}"
accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
-m lerobot.scripts.lerobot_train \
--policy.type=pi052 \
--policy.recipe_path=recipes/pi052_hirobot.yaml \
--dataset.repo_id="$DATASET" \
--dataset.revision=main \
--dataset.video_backend=pyav \
--output_dir="$OUTPUT_DIR" \
--job_name="$JOB_NAME" \
--policy.repo_id="$POLICY_REPO_ID" \
--policy.compile_model=false \
--policy.device=cuda \
--policy.tokenizer_max_length=512 \
--policy.text_loss_weight=1.0 \
--policy.flow_loss_weight=10.0 \
--policy.unfreeze_lm_head=true \
--steps="$STEPS" \
--policy.scheduler_decay_steps="$STEPS" \
--batch_size="$BATCH_SIZE" \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=hirobot \
--log_freq=100 \
--save_freq="$STEPS" \
--num_workers=0 \
--dataset.image_transforms.enable=true \
--dataset.image_transforms.max_num_transforms=3 \
--dataset.image_transforms.random_order=true \
--policy.plan_dropout_prob=0.30 \
--policy.memory_dropout_prob=0.30 \
--policy.subtask_dropout_prob=0.20
@@ -0,0 +1,92 @@
# π0.5 v2 — Hi-Robot / MEM / ECoT blend, reproducing the paper's
# hierarchical inference recipe on lerobot.
#
# Architecturally identical blend to ``smolvla2_hirobot.yaml`` — same
# five sub-recipes (memory_update, user_interjection_response,
# high_level_subtask, low_level_execution, ask_vqa_*) with the same
# message layouts. The only difference is which backbone the renderer's
# output is fed into:
#
# * SmolVLA2 calls SmolVLM's chat-template tokenizer
# (``apply_chat_template`` with chat-pretrained role markers).
# * π0.5 v2 concatenates the rendered messages as ``Role: content``
# plain text, since PaliGemma is not chat-pretrained. See
# ``PI052TextTokenizerStep`` in ``policies/pi052/text_processor_pi052.py``.
#
# Same supervision target convention as ``smolvla2_hirobot.yaml``: the
# ``high_level_subtask`` recipe targets ``${subtask}`` (the *current*
# active span at every frame) rather than ``${next_subtask}`` (which
# is empty on stable phases and used to train the model to emit
# newlines).
blend:
memory_update:
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "emitted_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.16
bindings:
prior_plan: "nth_prev(style=plan, offset=1)"
current_plan: "emitted_at(t, style=plan)"
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
# Pi 0.5 / Pi 0.7 supervision: predict the *current* active subtask
# at every frame from task + plan + memory + visual prefix.
# ``if_present: subtask`` skips frames with no active span instead of
# supervising an empty target (the failure mode that produces newline
# collapse).
high_level_subtask:
weight: 0.15
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.35
messages:
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
ask_vqa_top:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
+19
View File
@@ -128,6 +128,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "pi052":
from .pi052.modeling_pi052 import PI052Policy
return PI052Policy
elif name == "sac":
from .sac.modeling_sac import SACPolicy
@@ -200,6 +204,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "pi052":
from .pi052.configuration_pi052 import PI052Config
return PI052Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
@@ -370,6 +378,17 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif policy_cfg.type == "pi052":
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
# come before the PI05Config isinstance check below (otherwise
# pi052 would silently pick up π0.5's processor).
from .pi052.processor_pi052 import make_pi052_pre_post_processors
processors = make_pi052_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI05Config):
from .pi05.processor_pi05 import make_pi05_pre_post_processors
+42
View File
@@ -0,0 +1,42 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
inference recipe on lerobot.
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
(the "high-level subtask prediction" of the paper, §IV.D),
* AR text generation at inference (:meth:`PI052Policy.select_message`),
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
text head against missing context at inference.
See ``src/lerobot/configs/recipes/pi052_hirobot.yaml`` for the
canonical training recipe and
``examples/training/pi052_hirobot.slurm`` for the launcher.
"""
from .configuration_pi052 import PI052Config
from .modeling_pi052 import PI052Policy
from .processor_pi052 import make_pi052_pre_post_processors
from .text_processor_pi052 import PI052TextTokenizerStep
__all__ = [
"PI052Config",
"PI052Policy",
"PI052TextTokenizerStep",
"make_pi052_pre_post_processors",
]
@@ -0,0 +1,109 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
hierarchical inference recipe.
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
~300M Gemma action expert, joint training with FAST tokens during
pre-train and flow matching during post-train), but with the
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
to predict both:
* **subtask strings** at the high level (cross-entropy on the LM
head), and
* **action chunks** at the low level (flow matching on the
action-expert tokens).
This is the dual-head co-training pattern from the paper:
L = H(x, f_θ_text) + α * ω - a - f_θ_action(a_τ, o, )²
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
inference into a text-prediction step followed by an action-prediction
step, which mirrors what ``SmolVLA2Runtime`` already does on a
SmolVLM2 backbone.
"""
from dataclasses import dataclass
from lerobot.configs import PreTrainedConfig
from ..pi05.configuration_pi05 import PI05Config
@PreTrainedConfig.register_subclass("pi052")
@dataclass
class PI052Config(PI05Config):
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head
config. Same recipe-driven training surface; the only differences
are which backbone the policy uses (PaliGemma here vs SmolVLM2
there) and the default loss-weight scale (paper §IV.D uses
``α=10`` between the two heads, which we encode as
``flow_loss_weight=10, text_loss_weight=1``).
"""
# Recipe / language stack ---------------------------------------------
recipe_path: str | None = "recipes/pi052_hirobot.yaml"
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
shipped alongside this policy. Set to ``None`` to disable recipe
rendering and fall back to π0.5's single-task ``Task: ... Action:``
prompt path (unannotated datasets keep working that way)."""
apply_chat_template: bool = False
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
chat template. So unlike SmolVLA2 we don't apply one. The recipe
renderer's output is concatenated as a plain prefix + assistant
suffix instead, mirroring how the π0.5 paper's high-level inference
samples text auto-regressively after the prefix."""
# Loss weights --------------------------------------------------------
# Paper §IV.D: total = H(text) + α * MSE(flow), α = 10. We split
# the same total into two configurable knobs so individual scaling
# is recoverable.
text_loss_weight: float = 1.0
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
text training entirely (reverts to flow-only / π0.5 behaviour)."""
flow_loss_weight: float = 10.0
"""Weight on the action-expert flow-matching term. Default ``10.0``
matches the paper's α."""
# Backbone training ---------------------------------------------------
unfreeze_lm_head: bool = True
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
The existing ``PI05Policy`` zeroes / freezes the head on load
because it never reads from it. Must be ``True`` for π0.5-style
hierarchical inference."""
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
# Same regulariser surface as SmolVLA2: randomly drop non-target
# context messages so the LM head learns to handle missing /
# stale plan / memory at inference. Defaults to 0.0 so behaviour
# is identical until explicitly enabled.
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
def __post_init__(self) -> None:
super().__post_init__()
# Backbone needs gradients flowing through the text head when
# we're training it. Override the π0.5 default
# (``train_expert_only=True``) unless the user explicitly opts
# out of text training via ``text_loss_weight=0``.
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
self.train_expert_only = False
@@ -0,0 +1,339 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 policy — dual-head training & hierarchical inference.
A thin subclass of :class:`PI05Policy` that:
* keeps the PaliGemma ``lm_head`` unfrozen during fine-tuning
(``PI05Policy`` zeroes / freezes it because it never reads from
the head; ``PI052Config.unfreeze_lm_head`` flips that),
* adds a ``text_loss`` term computed via cross-entropy on
``text_labels`` (built by ``PI052TextTokenizerStep``),
* adds :meth:`select_message` for AR text generation at inference
(the high-level step in the π0.5 paper's two-stage inference loop),
* combines both losses in :meth:`forward` per Eq. (1) of the paper:
L = H(x, f_θ_text) + α * ω - a - f_θ_action(...)²
with α controllable via ``config.flow_loss_weight``.
The same multi-rate runtime that drives ``SmolVLA2Runtime`` (see
``lerobot.policies.smolvla2.inference``) can drive this policy too
both expose ``predict_action_chunk`` for the action expert and
``select_message`` for the LM head.
"""
from __future__ import annotations
import logging
import math
from typing import Any
import torch
from torch import Tensor
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from ..pi05.configuration_pi05 import PI05Config
from ..pi05.modeling_pi05 import PI05Policy
from .configuration_pi052 import PI052Config
logger = logging.getLogger(__name__)
class PI052Policy(PI05Policy):
"""π0.5 with the PaliGemma LM head re-enabled."""
config_class = PI052Config
name = "pi052"
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
super().__init__(config, **kwargs)
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
# freezes a few terminal layers when ``train_expert_only`` is
# the (default) True. We re-enable the head if the user
# wants text supervision.
if config.text_loss_weight > 0 and config.unfreeze_lm_head:
self._unfreeze_lm_head()
# ------------------------------------------------------------------
# Head unfreeze helper
# ------------------------------------------------------------------
def _unfreeze_lm_head(self) -> None:
"""Walk the PaliGemma submodules and re-enable gradients on
``lm_head`` + the immediately preceding norm / last text-model
layer that ``PI05Policy`` typically freezes."""
backbone = self.model.paligemma_with_expert.paligemma
if hasattr(backbone, "lm_head"):
for p in backbone.lm_head.parameters():
p.requires_grad_(True)
# The text model's final norm and last transformer block —
# mirror SmolVLA2's logic, which finds these dynamically by
# the trainable=False parameters that point at the head's
# neighbourhood.
text_model = getattr(backbone, "model", None)
text_model = getattr(text_model, "language_model", text_model)
if text_model is None:
return
norm = getattr(text_model, "norm", None)
if norm is not None:
for p in norm.parameters():
p.requires_grad_(True)
layers = getattr(text_model, "layers", None)
if isinstance(layers, (list, torch.nn.ModuleList)) and len(layers) > 0:
for p in layers[-1].parameters():
p.requires_grad_(True)
# ------------------------------------------------------------------
# Forward (dual loss: flow + text)
# ------------------------------------------------------------------
def forward(
self,
batch: dict[str, Tensor],
reduction: str = "mean",
) -> tuple[Tensor, dict]:
"""Dual-head forward: flow-matching loss + text-CE loss.
When ``text_labels`` isn't present in the batch (e.g. the
recipe wasn't applied), we delegate to ``PI05Policy.forward``
unchanged. Otherwise we compute both losses and sum them with
``flow_loss_weight`` / ``text_loss_weight``.
"""
text_labels = batch.get("text_labels")
predict_actions_t = batch.get("predict_actions")
run_flow = (
self.config.flow_loss_weight > 0
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
)
run_text = self.config.text_loss_weight > 0 and text_labels is not None
loss_dict: dict[str, Any] = {}
total: Tensor | None = None
if run_flow:
flow_loss, flow_dict = super().forward(batch, reduction=reduction)
for k, v in flow_dict.items():
loss_dict[f"flow_{k}"] = v
loss_dict["flow_loss"] = (
flow_loss.item() if isinstance(flow_loss, Tensor) and flow_loss.dim() == 0 else float("nan")
)
total = self.config.flow_loss_weight * flow_loss
if run_text:
text_loss = self._compute_text_loss(batch, text_labels)
loss_dict["text_loss"] = float(text_loss.detach().item())
total = (
self.config.text_loss_weight * text_loss
if total is None
else total + self.config.text_loss_weight * text_loss
)
if total is None:
# Both flow and text disabled — make this an obvious bug
# rather than a silent zero loss.
raise RuntimeError(
"PI052Policy.forward: both flow_loss_weight and "
"text_loss_weight are 0 (or text_labels missing) — "
"nothing to train."
)
loss_dict["loss"] = float(total.detach().item()) if total.dim() == 0 else float("nan")
if reduction == "none":
return total.expand(batch[OBS_LANGUAGE_TOKENS].shape[0]), loss_dict
return total, loss_dict
# ------------------------------------------------------------------
# Text loss
# ------------------------------------------------------------------
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
"""Cross-entropy on PaliGemma's LM head over the supervised span.
Re-uses the same prefix-embedding path the flow head does:
embed images + state + language tokens, run a forward pass,
slice out the per-token logits at the supervised positions,
compute CE.
"""
from torch.nn import functional as F # noqa: PLC0415
images, img_masks = self.model._preprocess_images(batch)
tokens = batch[OBS_LANGUAGE_TOKENS]
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
images, img_masks, tokens, masks
)
# PaliGemma's text path: forward the prefix through the
# backbone *without* the action expert. We piggy-back on the
# existing PaliGemmaWithExpertModel.forward — it accepts a
# list of expert inputs and returns parallel outputs.
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
attention_mask=att_2d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
fill_kv_cache=True,
)
if vlm_out is None:
raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.")
# Logits over the vocab via the PaliGemma lm_head.
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
logits = lm_head(vlm_out.to(lm_head.weight.dtype))
# Shift for next-token prediction: predict token[i+1] from
# hidden[i]. Both ``logits`` and ``text_labels`` are over the
# same sequence length, so shift logits[:-1] vs labels[1:].
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = text_labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
return loss
# ------------------------------------------------------------------
# select_message — AR text generation at inference
# ------------------------------------------------------------------
def select_message(
self,
batch: dict[str, Tensor],
*,
max_new_tokens: int = 128,
min_new_tokens: int = 0,
eos_token_id: int | None = None,
temperature: float = 0.0,
top_p: float = 1.0,
tokenizer: Any = None,
) -> str:
"""Generate text continuation from a multimodal prefix.
Mirrors ``SmolVLA2Policy.select_message`` so the same
:class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime`
can drive π0.5 v2 unchanged.
"""
self.eval()
if tokenizer is None:
from transformers import AutoTokenizer # noqa: PLC0415
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
special_ids: set[int] = set()
try:
for sid in (tokenizer.all_special_ids or []):
if sid is not None:
special_ids.add(int(sid))
except Exception: # noqa: BLE001
pass
if eos_token_id is not None:
special_ids.add(int(eos_token_id))
images, img_masks = self.model._preprocess_images(batch)
tokens = batch[OBS_LANGUAGE_TOKENS]
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
images, img_masks, tokens, masks
)
device = prefix_embs.device
bsize = prefix_embs.shape[0]
emb_dim = prefix_embs.shape[-1]
text_emb_scale = math.sqrt(emb_dim)
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
current_embs = prefix_embs
current_pad = prefix_pad_masks
current_att = prefix_att_masks
generated: list[int] = []
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
backbone = self.model.paligemma_with_expert
lm_head = backbone.paligemma.lm_head
for _ in range(max_new_tokens):
att_2d = make_att_2d_masks(current_pad, current_att)
position_ids = torch.cumsum(current_pad, dim=1) - 1
(vlm_out, _), _ = backbone.forward(
attention_mask=att_2d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[current_embs, None],
use_cache=False,
fill_kv_cache=True,
)
if vlm_out is None:
break
last = vlm_out[:, -1:].to(lm_head.weight.dtype)
logits_step = lm_head(last)[:, -1] # (B, V)
if special_ids and len(generated) < min_new_tokens:
for sid in special_ids:
logits_step[..., sid] = float("-inf")
next_ids = self._sample_next_token(logits_step, temperature, top_p)
tok_id = int(next_ids[0].item())
generated.append(tok_id)
if eos_token_id is not None and tok_id == eos_token_id:
break
new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0))
new_emb = new_emb * text_emb_scale
current_embs = torch.cat([current_embs, new_emb], dim=1)
current_pad = torch.cat([current_pad, ones_step], dim=1)
current_att = torch.cat([current_att, ones_step], dim=1)
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
if not decoded and generated:
try:
self._last_select_message_debug = (
f"raw_ids={generated[:16]} "
f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
)
except Exception: # noqa: BLE001
self._last_select_message_debug = f"raw_ids={generated[:16]}"
else:
self._last_select_message_debug = ""
return decoded
@staticmethod
def _sample_next_token(logits: Tensor, temperature: float, top_p: float) -> Tensor:
if temperature <= 0.0:
return logits.argmax(dim=-1)
scaled = logits / max(temperature, 1e-6)
probs = torch.softmax(scaled, dim=-1)
if top_p < 1.0:
sorted_p, sorted_ix = torch.sort(probs, descending=True, dim=-1)
cum = torch.cumsum(sorted_p, dim=-1)
mask = cum > top_p
mask[..., 0] = False
sorted_p = sorted_p.masked_fill(mask, 0.0)
sorted_p = sorted_p / sorted_p.sum(dim=-1, keepdim=True).clamp_min(1e-8)
choice = torch.multinomial(sorted_p, num_samples=1)
return sorted_ix.gather(-1, choice).squeeze(-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
@@ -0,0 +1,148 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 pre/post-processor factory.
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
rename observations
add batch dim
relative-action prep (inherited from π0.5)
NormalizerProcessorStep
RenderMessagesStep recipe messages, target_message_indices,
message_streams (PR 1 of the steerable
stack)
PI052TextTokenizerStep messages input_ids + label mask +
predict_actions
DeviceProcessorStep
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
so unannotated datasets keep working.
Post-processor is unchanged from π0.5.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
RenderMessagesStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
from .configuration_pi052 import PI052Config
from .text_processor_pi052 import PI052TextTokenizerStep
def make_pi052_pre_post_processors(
config: PI052Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build PI0.5-v2's pre/post-processor pipelines.
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
"""
if not config.recipe_path:
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
recipe = _load_recipe(config.recipe_path)
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
relative_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
RenderMessagesStep(recipe=recipe),
PI052TextTokenizerStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AbsoluteActionsProcessorStep(
enabled=config.use_relative_actions,
relative_step=relative_step,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
def _load_recipe(path_str: str) -> TrainingRecipe:
"""Resolve ``path_str`` to a ``TrainingRecipe``.
Accepts an absolute path or a path relative to
``src/lerobot/configs/`` (same lookup rules as
``make_smolvla2_pre_post_processors``).
"""
p = Path(path_str)
if not p.is_absolute() and not p.exists():
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
configs_dir = Path(_recipe_module.__file__).resolve().parent
candidate = configs_dir / path_str
if candidate.exists():
p = candidate
return TrainingRecipe.from_yaml(p)
@@ -0,0 +1,303 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 text-tokenisation step.
PaliGemma is *not* chat-pretrained, so unlike SmolVLA2 we can't lean on
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
messages as plain text with simple ``User: ... Assistant: ...`` role
delimiters matching the prompt format π0.5 uses in the paper
(``Task: ... State: ... Action: ...``).
Outputs:
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` the
concatenated prompt tokenised by the PaliGemma tokenizer (the same
one ``processor_pi05`` already uses).
* ``text_labels`` same shape as token ids, ``-100`` everywhere except
positions belonging to messages whose index is in
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
those positions via the PaliGemma ``lm_head``.
* ``predict_actions`` bool tensor, ``True`` iff any of the rendered
target messages has ``message_streams[i] == "low_level"``. Same
semantics as the SmolVLA2 step.
"""
from __future__ import annotations
import copy
import logging
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
logger = logging.getLogger(__name__)
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
"""Normalise a message's content to a plain string.
The recipe renderer can emit ``content`` as a string OR as a list
of HF-style multimodal blocks (``{type: text, text: ...}``,
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
only consume strings, so we flatten: drop image blocks (cameras
flow through ``observation.images.*`` separately) and join text
block texts.
"""
new = dict(message)
new.pop("stream", None)
new.pop("target", None)
content = new.get("content")
if content is None:
new["content"] = ""
elif isinstance(content, str):
pass
elif isinstance(content, list):
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text":
t = block.get("text", "")
if isinstance(t, str):
parts.append(t)
new["content"] = "\n".join(parts)
else:
new["content"] = str(content)
return new
def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]:
"""Concatenate messages into the π0.5-style flat prompt.
Returns:
prompt: the full text the tokenizer will consume.
msg_spans: list of ``(char_start, char_end)`` covering each
message's content within ``prompt``. The
target-mask builder uses this to find the
character ranges belonging to the supervised
messages.
"""
parts: list[str] = []
spans: list[tuple[int, int]] = []
cursor = 0
for m in messages:
role = m.get("role", "user")
content = m.get("content", "") or ""
# Role tag + newline. The model has to learn to emit the same
# role tokens at generation time, which is fine for greedy
# decoding because the chat template is implicit in the
# supervised target span.
header = f"{role.capitalize()}: "
# span covers ONLY the content portion (so labels are computed
# over the supervised payload, not the role tag).
full = header + content + "\n"
start = cursor + len(header)
end = start + len(content)
parts.append(full)
spans.append((start, end))
cursor += len(full)
return "".join(parts), spans
@dataclass
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
class PI052TextTokenizerStep(ProcessorStep):
"""Render messages → token ids + label mask + predict_actions flag.
π0.5 analogue of ``SmolVLA2ChatTokenizerStep``. No chat template;
concatenates messages as ``User: ... \\nAssistant: ...`` text.
"""
tokenizer_name: str = "google/paligemma-3b-pt-224"
max_length: int = 200
padding: str = "max_length"
padding_side: str = "right"
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
interjection_dropout_prob: float = 0.0
dropout_seed: int | None = None
def __post_init__(self) -> None:
self._tokenizer: Any = None
def _ensure_tokenizer(self) -> Any:
if self._tokenizer is not None:
return self._tokenizer
from transformers import AutoTokenizer # noqa: PLC0415
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
return self._tokenizer
# ------------------------------------------------------------------
# Pipeline step
# ------------------------------------------------------------------
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
transition = transition.copy()
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
messages = complementary.get("messages") or []
target_indices = list(complementary.get("target_message_indices") or [])
message_streams = list(complementary.get("message_streams") or [])
if not messages:
# No recipe was rendered — caller will fall back to the
# plain Pi0.5 prompt path. We pass the transition through
# unmodified.
return transition
# Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping
# after removal.
if (
self.plan_dropout_prob
or self.memory_dropout_prob
or self.subtask_dropout_prob
or self.interjection_dropout_prob
):
messages, target_indices = self._apply_prompt_dropout(
messages,
target_indices,
complementary,
)
messages = [_strip_blocks(m) for m in messages]
prompt, spans = _format_messages(messages)
tokenizer = self._ensure_tokenizer()
encoded = tokenizer(
prompt,
max_length=self.max_length,
padding=self.padding,
truncation=True,
return_tensors="pt",
return_offsets_mapping=True,
padding_side=self.padding_side,
)
input_ids = encoded["input_ids"][0]
attention_mask = encoded["attention_mask"][0].bool()
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
# Build label mask: -100 everywhere except over supervised
# target message char ranges.
labels = torch.full_like(input_ids, fill_value=-100)
for idx in target_indices:
if idx >= len(spans):
continue
char_start, char_end = spans[idx]
for token_pos in range(input_ids.shape[0]):
if not attention_mask[token_pos]:
continue
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
if tok_end <= char_start or tok_start >= char_end:
continue
labels[token_pos] = input_ids[token_pos]
predict_actions = torch.tensor(
bool(any(message_streams[i] == "low_level" for i in target_indices if i < len(message_streams))),
dtype=torch.bool,
)
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0)
obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0)
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = {
**complementary,
"text_labels": labels.unsqueeze(0),
"predict_actions": predict_actions.unsqueeze(0),
}
return transition
# ------------------------------------------------------------------
# Per-component prompt dropout (Pi0.7 §V.E)
# ------------------------------------------------------------------
def _apply_prompt_dropout(
self,
messages: list[dict[str, Any]],
target_indices: list[int],
complementary: dict[str, Any],
) -> tuple[list[dict[str, Any]], list[int]]:
"""Drop messages classified as plan/memory/subtask context.
Targets are *never* dropped (they're the supervised payload).
Re-maps target_indices to the new positions after drops.
"""
import random # noqa: PLC0415
seed = self.dropout_seed
if seed is None:
seed_src = complementary.get("dataset_index") or complementary.get("frame_index") or 0
try:
seed = int(seed_src)
except (TypeError, ValueError):
seed = 0
rng = random.Random(seed)
keep_indices: list[int] = []
for idx, msg in enumerate(messages):
if idx in target_indices:
keep_indices.append(idx)
continue
kind = _classify_for_dropout(msg)
prob = {
"plan": self.plan_dropout_prob,
"memory": self.memory_dropout_prob,
"subtask": self.subtask_dropout_prob,
"interjection": self.interjection_dropout_prob,
}.get(kind, 0.0)
if prob > 0.0 and rng.random() < prob:
continue
keep_indices.append(idx)
# Build remap and apply
new_messages = [messages[i] for i in keep_indices]
old_to_new = {old: new for new, old in enumerate(keep_indices)}
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
return new_messages, new_targets
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
"""Heuristic content-prefix classifier — mirrors SmolVLA2's."""
content = message.get("content")
if isinstance(content, list):
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
content = " ".join(text_parts)
elif content is None:
return None
elif not isinstance(content, str):
return None
s = content.strip()
if s.startswith("Plan:") or s.startswith("Previous plan"):
return "plan"
if s.startswith("Memory:") or s.startswith("Previous memory"):
return "memory"
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
return "subtask"
return None