From ef5879a02a5554f91310770ba4c1a7e75d420e2b Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 13 May 2026 10:59:26 +0200 Subject: [PATCH] =?UTF-8?q?feat(pi052):=20=CF=800.5=20v2=20=E2=80=94=20ful?= =?UTF-8?q?l=20reproduction=20of=20the=20=CF=800.5=20paper=20recipe?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New ``lerobot.policies.pi052`` (parallel to ``smolvla2``) that adds text-prediction + hierarchical-inference on top of the existing π0.5 implementation. Mirrors the paper's §IV.D dual-head training: L = H(text) + α * ‖ω - a - f_θ_action(...)‖², α = 10 Components: * ``configuration_pi052.py`` thin PI05Config subclass; adds recipe_path, text/flow loss weights (default α=10 per paper), prompt dropout knobs, ``unfreeze_lm_head``. * ``text_processor_pi052.py`` PI052TextTokenizerStep — concatenates rendered messages as ``Role: ...`` plain text (PaliGemma has no chat template), tokenises with the PaliGemma tokenizer, builds a label mask covering supervised target spans. Includes Pi 0.7 §V.E per-component prompt dropout. * ``processor_pi052.py`` make_pi052_pre_post_processors — Rename + Batch + Relative + Normalize + RenderMessagesStep + PI052TextTokenizerStep + Device. Falls back to π0.5's plain pipeline when recipe_path is unset. * ``modeling_pi052.py`` PI052Policy(PI05Policy) — re-enables PaliGemma ``lm_head``, computes text_loss via CE on the supervised span, sums with flow_loss in forward(), and adds select_message for AR text generation at inference (same surface as SmolVLA2Policy.select_message so SmolVLA2Runtime drives it unchanged). Plus the supporting plumbing: * recipe ``configs/recipes/pi052_hirobot.yaml`` — same Hi-Robot blend as smolvla2_hirobot.yaml, with the same ``${subtask}`` / ``if_present`` supervision fix (current span at every frame, not ``${next_subtask}``). * SLURM ``examples/training/pi052_hirobot.slurm`` — full training command matching the SmolVLA2 launcher. * factory registration: ``--policy.type=pi052`` resolves to PI052Policy with the new processor. Same multi-rate runtime (``lerobot.policies.smolvla2.inference``) drives this policy too — both expose ``predict_action_chunk`` for the action expert and ``select_message`` for the LM head. 
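Illustrative two-stage inference step (a sketch; the prompt re-render
helper is hypothetical, the two policy entry points are the real
surface):

    subtask = policy.select_message(batch, max_new_tokens=32)  # LM head
    batch = rerender_with_subtask(batch, subtask)              # hypothetical glue
    actions = policy.predict_action_chunk(batch)               # flow head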
Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/training/pi052_hirobot.slurm | 75 ++++ .../configs/recipes/pi052_hirobot.yaml | 92 +++++ src/lerobot/policies/factory.py | 19 + src/lerobot/policies/pi052/__init__.py | 42 +++ .../policies/pi052/configuration_pi052.py | 109 ++++++ src/lerobot/policies/pi052/modeling_pi052.py | 339 ++++++++++++++++++ src/lerobot/policies/pi052/processor_pi052.py | 148 ++++++++ .../policies/pi052/text_processor_pi052.py | 303 ++++++++++++++++ 8 files changed, 1127 insertions(+) create mode 100644 examples/training/pi052_hirobot.slurm create mode 100644 src/lerobot/configs/recipes/pi052_hirobot.yaml create mode 100644 src/lerobot/policies/pi052/__init__.py create mode 100644 src/lerobot/policies/pi052/configuration_pi052.py create mode 100644 src/lerobot/policies/pi052/modeling_pi052.py create mode 100644 src/lerobot/policies/pi052/processor_pi052.py create mode 100644 src/lerobot/policies/pi052/text_processor_pi052.py diff --git a/examples/training/pi052_hirobot.slurm b/examples/training/pi052_hirobot.slurm new file mode 100644 index 000000000..e0a902177 --- /dev/null +++ b/examples/training/pi052_hirobot.slurm @@ -0,0 +1,75 @@ +#!/bin/bash +#SBATCH --job-name=pi052-hirobot +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=48:00:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=8 + +# π0.5 v2 training — reproduces the π0.5 paper's hierarchical recipe. +# +# Same recipe blend as the SmolVLA2 stack (recipes/pi052_hirobot.yaml), +# just on the PaliGemma 2B + Gemma-300m action-expert backbone the +# paper uses. The text head learns subtask prediction via cross- +# entropy on supervised spans; the action expert learns the flow +# field. Paper §IV.D mixes the two losses with α=10, which we encode +# as flow_loss_weight=10 / text_loss_weight=1. 
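+#
+# The knobs below are read as ``${VAR:-default}``, so a submit-time
+# environment override is enough (illustrative values):
+#
+#   DATASET=my_org/my_dataset BATCH_SIZE=16 STEPS=30000 \
+#     sbatch examples/training/pi052_hirobot.slurm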
+
+set -euo pipefail
+
+cd "${LEROBOT_ROOT:-$HOME/lerobot}"
+
+export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
+export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
+export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
+export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
+export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
+
+DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
+POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/pi052_hirobot_super_poulain}"
+JOB_NAME="${JOB_NAME:-pi052-hirobot-super-poulain}"
+NUM_PROCESSES="${NUM_PROCESSES:-8}"
+BATCH_SIZE="${BATCH_SIZE:-32}"
+STEPS="${STEPS:-15000}"
+RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
+OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/pi052_hirobot_${STEPS}_${RUN_ID}}"
+
+echo "Training pi052 on $DATASET"
+echo " GPUs: $NUM_PROCESSES"
+echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
+echo " steps: $STEPS"
+echo " output: $OUTPUT_DIR"
+echo " loss mix: flow_loss_weight=10 (paper α), text_loss_weight=1"
+echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}"
+
+accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
+  -m lerobot.scripts.lerobot_train \
+  --policy.type=pi052 \
+  --policy.recipe_path=recipes/pi052_hirobot.yaml \
+  --dataset.repo_id="$DATASET" \
+  --dataset.revision=main \
+  --dataset.video_backend=pyav \
+  --output_dir="$OUTPUT_DIR" \
+  --job_name="$JOB_NAME" \
+  --policy.repo_id="$POLICY_REPO_ID" \
+  --policy.compile_model=false \
+  --policy.device=cuda \
+  --policy.tokenizer_max_length=512 \
+  --policy.text_loss_weight=1.0 \
+  --policy.flow_loss_weight=10.0 \
+  --policy.unfreeze_lm_head=true \
+  --steps="$STEPS" \
+  --policy.scheduler_decay_steps="$STEPS" \
+  --batch_size="$BATCH_SIZE" \
+  --wandb.enable=true \
+  --wandb.disable_artifact=true \
+  --wandb.project=hirobot \
+  --log_freq=100 \
+  --save_freq="$STEPS" \
+  --num_workers=0 \
+  --dataset.image_transforms.enable=true \
+  --dataset.image_transforms.max_num_transforms=3 \
+  --dataset.image_transforms.random_order=true \
+  --policy.plan_dropout_prob=0.30 \
+  --policy.memory_dropout_prob=0.30 \
+  --policy.subtask_dropout_prob=0.20
diff --git a/src/lerobot/configs/recipes/pi052_hirobot.yaml b/src/lerobot/configs/recipes/pi052_hirobot.yaml
new file mode 100644
index 000000000..f0c8982d6
--- /dev/null
+++ b/src/lerobot/configs/recipes/pi052_hirobot.yaml
@@ -0,0 +1,92 @@
+# π0.5 v2 — Hi-Robot / MEM / ECoT blend, reproducing the paper's
+# hierarchical inference recipe on lerobot.
+#
+# Architecturally identical blend to ``smolvla2_hirobot.yaml`` — same
+# five sub-recipes (memory_update, user_interjection_response,
+# high_level_subtask, low_level_execution, ask_vqa_*) with the same
+# message layouts. The only difference is which backbone the renderer's
+# output is fed into:
+#
+# * SmolVLA2 calls SmolVLM's chat-template tokenizer
+#   (``apply_chat_template`` with chat-pretrained role markers).
+# * π0.5 v2 concatenates the rendered messages as ``Role: content``
+#   plain text, since PaliGemma is not chat-pretrained. See
+#   ``PI052TextTokenizerStep`` in ``policies/pi052/text_processor_pi052.py``.
+#
+# Same supervision target convention as ``smolvla2_hirobot.yaml``: the
+# ``high_level_subtask`` recipe targets ``${subtask}`` (the *current*
+# active span at every frame) rather than ``${next_subtask}`` (which
+# is empty during stable phases; supervising it taught the model to
+# emit bare newlines).
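+#
+# For illustration, a single ``high_level_subtask`` sample renders
+# (before tokenisation) roughly as:
+#
+#   User: clear the table
+#   Plan: ...
+#   Memory: ...
+#   Assistant: pick up the blue plate
+#
+# with only the assistant payload (the subtask string) supervised; see
+# ``PI052TextTokenizerStep`` for the label mask.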
+ +blend: + + memory_update: + weight: 0.10 + bindings: + prior_memory: "nth_prev(style=memory, offset=1)" + current_memory: "emitted_at(t, style=memory)" + completed_subtask: "nth_prev(style=subtask, offset=1)" + messages: + - {role: user, content: "${task}", stream: high_level} + - {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory} + - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask} + - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory} + + user_interjection_response: + weight: 0.16 + bindings: + prior_plan: "nth_prev(style=plan, offset=1)" + current_plan: "emitted_at(t, style=plan)" + interjection: "emitted_at(t, style=interjection)" + speech: "emitted_at(t, role=assistant, tool_name=say)" + messages: + - {role: user, content: "${task}", stream: high_level} + - {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan} + - {role: user, content: "${interjection}", stream: high_level, if_present: interjection} + - {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech} + + # Pi 0.5 / Pi 0.7 supervision: predict the *current* active subtask + # at every frame from task + plan + memory + visual prefix. + # ``if_present: subtask`` skips frames with no active span instead of + # supervising an empty target (the failure mode that produces newline + # collapse). + high_level_subtask: + weight: 0.15 + messages: + - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level} + - {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask} + + low_level_execution: + weight: 0.35 + messages: + - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level} + - {role: assistant, content: "${subtask}", stream: low_level, target: true} + + ask_vqa_top: + weight: 0.10 + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)" + messages: + - role: user + stream: high_level + if_present: vqa_query + content: + - {type: image, feature: observation.images.front} + - {type: text, text: "${vqa_query}"} + - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} + + ask_vqa_wrist: + weight: 0.10 + bindings: + vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)" + vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)" + messages: + - role: user + stream: high_level + if_present: vqa_query + content: + - {type: image, feature: observation.images.wrist} + - {type: text, text: "${vqa_query}"} + - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa} diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index e718e0e48..47777f1a8 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -128,6 +128,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .pi05.modeling_pi05 import PI05Policy return PI05Policy + elif name == "pi052": + from .pi052.modeling_pi052 import PI052Policy + + return PI052Policy elif name == "sac": from .sac.modeling_sac import SACPolicy @@ -200,6 +204,10 @@ def make_policy_config(policy_type: str, 
**kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi05": return PI05Config(**kwargs) + elif policy_type == "pi052": + from .pi052.configuration_pi052 import PI052Config + + return PI052Config(**kwargs) elif policy_type == "sac": return SACConfig(**kwargs) elif policy_type == "smolvla": @@ -370,6 +378,17 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif policy_cfg.type == "pi052": + # NOTE: PI052Config subclasses PI05Config, so this branch MUST + # come before the PI05Config isinstance check below (otherwise + # pi052 would silently pick up π0.5's processor). + from .pi052.processor_pi052 import make_pi052_pre_post_processors + + processors = make_pi052_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + elif isinstance(policy_cfg, PI05Config): from .pi05.processor_pi05 import make_pi05_pre_post_processors diff --git a/src/lerobot/policies/pi052/__init__.py b/src/lerobot/policies/pi052/__init__.py new file mode 100644 index 000000000..3e4c42f1c --- /dev/null +++ b/src/lerobot/policies/pi052/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical +inference recipe on lerobot. + +Extends :class:`lerobot.policies.pi05.PI05Policy` with: + +* recipe-driven training (PR 1's :class:`RenderMessagesStep`), +* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans + (the "high-level subtask prediction" of the paper, §IV.D), +* AR text generation at inference (:meth:`PI052Policy.select_message`), +* per-component prompt dropout (Pi 0.7 §V.E) for regularising the + text head against missing context at inference. + +See ``src/lerobot/configs/recipes/pi052_hirobot.yaml`` for the +canonical training recipe and +``examples/training/pi052_hirobot.slurm`` for the launcher. +""" + +from .configuration_pi052 import PI052Config +from .modeling_pi052 import PI052Policy +from .processor_pi052 import make_pi052_pre_post_processors +from .text_processor_pi052 import PI052TextTokenizerStep + +__all__ = [ + "PI052Config", + "PI052Policy", + "PI052TextTokenizerStep", + "make_pi052_pre_post_processors", +] diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py new file mode 100644 index 000000000..5cab001a9 --- /dev/null +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -0,0 +1,109 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's +hierarchical inference recipe. + +Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM + +~300M Gemma action expert, joint training with FAST tokens during +pre-train and flow matching during post-train), but with the +PaliGemma ``lm_head`` re-enabled so the same model can be supervised +to predict both: + + * **subtask strings** at the high level (cross-entropy on the LM + head), and + * **action chunks** at the low level (flow matching on the + action-expert tokens). + +This is the dual-head co-training pattern from the paper: + + L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, ℓ)‖² + +with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits +inference into a text-prediction step followed by an action-prediction +step, which mirrors what ``SmolVLA2Runtime`` already does on a +SmolVLM2 backbone. +""" + +from dataclasses import dataclass + +from lerobot.configs import PreTrainedConfig + +from ..pi05.configuration_pi05 import PI05Config + + +@PreTrainedConfig.register_subclass("pi052") +@dataclass +class PI052Config(PI05Config): + """π0.5 with the PaliGemma LM head re-enabled for subtask prediction. + + See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head + config. Same recipe-driven training surface; the only differences + are which backbone the policy uses (PaliGemma here vs SmolVLM2 + there) and the default loss-weight scale (paper §IV.D uses + ``α=10`` between the two heads, which we encode as + ``flow_loss_weight=10, text_loss_weight=1``). + """ + + # Recipe / language stack --------------------------------------------- + recipe_path: str | None = "recipes/pi052_hirobot.yaml" + """Path (absolute or relative to ``src/lerobot/configs/``) to a + ``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend + shipped alongside this policy. Set to ``None`` to disable recipe + rendering and fall back to π0.5's single-task ``Task: ... Action:`` + prompt path (unannotated datasets keep working that way).""" + + apply_chat_template: bool = False + """PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a + chat template. So unlike SmolVLA2 we don't apply one. The recipe + renderer's output is concatenated as a plain prefix + assistant + suffix instead, mirroring how the π0.5 paper's high-level inference + samples text auto-regressively after the prefix.""" + + # Loss weights -------------------------------------------------------- + # Paper §IV.D: total = H(text) + α * MSE(flow), α = 10. We split + # the same total into two configurable knobs so individual scaling + # is recoverable. + text_loss_weight: float = 1.0 + """Weight on the LM-head cross-entropy term. Set to ``0`` to disable + text training entirely (reverts to flow-only / π0.5 behaviour).""" + + flow_loss_weight: float = 10.0 + """Weight on the action-expert flow-matching term. Default ``10.0`` + matches the paper's α.""" + + # Backbone training --------------------------------------------------- + unfreeze_lm_head: bool = True + """Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning. + The existing ``PI05Policy`` zeroes / freezes the head on load + because it never reads from it. 
Must be ``True`` for π0.5-style + hierarchical inference.""" + + # Per-component prompt dropout (Pi0.7 §V.E) --------------------------- + # Same regulariser surface as SmolVLA2: randomly drop non-target + # context messages so the LM head learns to handle missing / + # stale plan / memory at inference. Defaults to 0.0 so behaviour + # is identical until explicitly enabled. + plan_dropout_prob: float = 0.0 + memory_dropout_prob: float = 0.0 + subtask_dropout_prob: float = 0.0 + + def __post_init__(self) -> None: + super().__post_init__() + # Backbone needs gradients flowing through the text head when + # we're training it. Override the π0.5 default + # (``train_expert_only=True``) unless the user explicitly opts + # out of text training via ``text_loss_weight=0``. + if self.text_loss_weight > 0 and self.unfreeze_lm_head: + self.train_expert_only = False diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py new file mode 100644 index 000000000..7f6f7cc86 --- /dev/null +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -0,0 +1,339 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""π0.5 v2 policy — dual-head training & hierarchical inference. + +A thin subclass of :class:`PI05Policy` that: + +* keeps the PaliGemma ``lm_head`` unfrozen during fine-tuning + (``PI05Policy`` zeroes / freezes it because it never reads from + the head; ``PI052Config.unfreeze_lm_head`` flips that), +* adds a ``text_loss`` term computed via cross-entropy on + ``text_labels`` (built by ``PI052TextTokenizerStep``), +* adds :meth:`select_message` for AR text generation at inference + (the high-level step in the π0.5 paper's two-stage inference loop), +* combines both losses in :meth:`forward` per Eq. (1) of the paper: + + L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(...)‖² + + with α controllable via ``config.flow_loss_weight``. + +The same multi-rate runtime that drives ``SmolVLA2Runtime`` (see +``lerobot.policies.smolvla2.inference``) can drive this policy too — +both expose ``predict_action_chunk`` for the action expert and +``select_message`` for the LM head. +""" + +from __future__ import annotations + +import logging +import math +from typing import Any + +import torch +from torch import Tensor + +from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS + +from ..pi05.configuration_pi05 import PI05Config +from ..pi05.modeling_pi05 import PI05Policy +from .configuration_pi052 import PI052Config + +logger = logging.getLogger(__name__) + + +class PI052Policy(PI05Policy): + """π0.5 with the PaliGemma LM head re-enabled.""" + + config_class = PI052Config + name = "pi052" + + def __init__(self, config: PI052Config, **kwargs: Any) -> None: + super().__init__(config, **kwargs) + # ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and + # freezes a few terminal layers when ``train_expert_only`` is + # the (default) True. 
We re-enable the head if the user + # wants text supervision. + if config.text_loss_weight > 0 and config.unfreeze_lm_head: + self._unfreeze_lm_head() + + # ------------------------------------------------------------------ + # Head unfreeze helper + # ------------------------------------------------------------------ + + def _unfreeze_lm_head(self) -> None: + """Walk the PaliGemma submodules and re-enable gradients on + ``lm_head`` + the immediately preceding norm / last text-model + layer that ``PI05Policy`` typically freezes.""" + backbone = self.model.paligemma_with_expert.paligemma + if hasattr(backbone, "lm_head"): + for p in backbone.lm_head.parameters(): + p.requires_grad_(True) + # The text model's final norm and last transformer block — + # mirror SmolVLA2's logic, which finds these dynamically by + # the trainable=False parameters that point at the head's + # neighbourhood. + text_model = getattr(backbone, "model", None) + text_model = getattr(text_model, "language_model", text_model) + if text_model is None: + return + norm = getattr(text_model, "norm", None) + if norm is not None: + for p in norm.parameters(): + p.requires_grad_(True) + layers = getattr(text_model, "layers", None) + if isinstance(layers, (list, torch.nn.ModuleList)) and len(layers) > 0: + for p in layers[-1].parameters(): + p.requires_grad_(True) + + # ------------------------------------------------------------------ + # Forward (dual loss: flow + text) + # ------------------------------------------------------------------ + + def forward( + self, + batch: dict[str, Tensor], + reduction: str = "mean", + ) -> tuple[Tensor, dict]: + """Dual-head forward: flow-matching loss + text-CE loss. + + When ``text_labels`` isn't present in the batch (e.g. the + recipe wasn't applied), we delegate to ``PI05Policy.forward`` + unchanged. Otherwise we compute both losses and sum them with + ``flow_loss_weight`` / ``text_loss_weight``. + """ + text_labels = batch.get("text_labels") + predict_actions_t = batch.get("predict_actions") + + run_flow = ( + self.config.flow_loss_weight > 0 + and (predict_actions_t is None or bool(predict_actions_t.any().item())) + ) + run_text = self.config.text_loss_weight > 0 and text_labels is not None + + loss_dict: dict[str, Any] = {} + total: Tensor | None = None + + if run_flow: + flow_loss, flow_dict = super().forward(batch, reduction=reduction) + for k, v in flow_dict.items(): + loss_dict[f"flow_{k}"] = v + loss_dict["flow_loss"] = ( + flow_loss.item() if isinstance(flow_loss, Tensor) and flow_loss.dim() == 0 else float("nan") + ) + total = self.config.flow_loss_weight * flow_loss + + if run_text: + text_loss = self._compute_text_loss(batch, text_labels) + loss_dict["text_loss"] = float(text_loss.detach().item()) + total = ( + self.config.text_loss_weight * text_loss + if total is None + else total + self.config.text_loss_weight * text_loss + ) + + if total is None: + # Both flow and text disabled — make this an obvious bug + # rather than a silent zero loss. + raise RuntimeError( + "PI052Policy.forward: both flow_loss_weight and " + "text_loss_weight are 0 (or text_labels missing) — " + "nothing to train." 
+            )
+
+        loss_dict["loss"] = float(total.detach().item()) if total.dim() == 0 else float("nan")
+        if reduction == "none":
+            return total.expand(batch[OBS_LANGUAGE_TOKENS].shape[0]), loss_dict
+        return total, loss_dict
+
+    # ------------------------------------------------------------------
+    # Text loss
+    # ------------------------------------------------------------------
+
+    def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
+        """Cross-entropy on PaliGemma's LM head over the supervised span.
+
+        Re-uses the same prefix-embedding path the flow head does:
+        embed images + state + language tokens, run a forward pass,
+        slice out the per-token logits at the supervised positions,
+        compute CE.
+        """
+        from torch.nn import functional as F  # noqa: PLC0415
+
+        images, img_masks = self.model._preprocess_images(batch)
+        tokens = batch[OBS_LANGUAGE_TOKENS]
+        masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
+
+        prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
+            images, img_masks, tokens, masks
+        )
+        # PaliGemma's text path: forward the prefix through the
+        # backbone *without* the action expert. We piggy-back on the
+        # existing PaliGemmaWithExpertModel.forward — it accepts a
+        # list of expert inputs and returns parallel outputs.
+        from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415
+
+        att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+        position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+        (vlm_out, _), _ = self.model.paligemma_with_expert.forward(
+            attention_mask=att_2d,
+            position_ids=position_ids,
+            past_key_values=None,
+            inputs_embeds=[prefix_embs, None],
+            use_cache=False,
+            fill_kv_cache=True,
+        )
+        if vlm_out is None:
+            raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.")
+
+        # Logits over the vocab via the PaliGemma lm_head.
+        lm_head = self.model.paligemma_with_expert.paligemma.lm_head
+        logits = lm_head(vlm_out.to(lm_head.weight.dtype))
+
+        # ``embed_prefix`` prepends the image embeddings, while
+        # ``text_labels`` only covers the tokenised prompt: slice the
+        # logits down to the trailing language-token region so logits
+        # and labels index the same tokens.
+        logits = logits[:, -tokens.shape[1] :]
+
+        # Shift for next-token prediction: predict token[i+1] from
+        # hidden[i]. After the slice, logits and ``text_labels`` cover
+        # the same sequence, so shift logits[:-1] vs labels[1:].
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = text_labels[..., 1:].contiguous()
+        loss = F.cross_entropy(
+            shift_logits.view(-1, shift_logits.size(-1)),
+            shift_labels.view(-1),
+            ignore_index=-100,
+        )
+        return loss
+
+    # ------------------------------------------------------------------
+    # select_message — AR text generation at inference
+    # ------------------------------------------------------------------
+
+    def select_message(
+        self,
+        batch: dict[str, Tensor],
+        *,
+        max_new_tokens: int = 128,
+        min_new_tokens: int = 0,
+        eos_token_id: int | None = None,
+        temperature: float = 0.0,
+        top_p: float = 1.0,
+        tokenizer: Any = None,
+    ) -> str:
+        """Generate text continuation from a multimodal prefix.
+
+        Mirrors ``SmolVLA2Policy.select_message`` so the same
+        :class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime`
+        can drive π0.5 v2 unchanged.
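+
+        Illustrative call (assumes ``batch`` already went through the
+        pi052 pre-processor)::
+
+            subtask = policy.select_message(batch, max_new_tokens=64)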
+ """ + self.eval() + + if tokenizer is None: + from transformers import AutoTokenizer # noqa: PLC0415 + + tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + if eos_token_id is None: + eos_token_id = tokenizer.eos_token_id + + special_ids: set[int] = set() + try: + for sid in (tokenizer.all_special_ids or []): + if sid is not None: + special_ids.add(int(sid)) + except Exception: # noqa: BLE001 + pass + if eos_token_id is not None: + special_ids.add(int(eos_token_id)) + + images, img_masks = self.model._preprocess_images(batch) + tokens = batch[OBS_LANGUAGE_TOKENS] + masks = batch[OBS_LANGUAGE_ATTENTION_MASK] + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( + images, img_masks, tokens, masks + ) + + device = prefix_embs.device + bsize = prefix_embs.shape[0] + emb_dim = prefix_embs.shape[-1] + text_emb_scale = math.sqrt(emb_dim) + ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device) + + current_embs = prefix_embs + current_pad = prefix_pad_masks + current_att = prefix_att_masks + generated: list[int] = [] + + from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + + backbone = self.model.paligemma_with_expert + lm_head = backbone.paligemma.lm_head + + for _ in range(max_new_tokens): + att_2d = make_att_2d_masks(current_pad, current_att) + position_ids = torch.cumsum(current_pad, dim=1) - 1 + (vlm_out, _), _ = backbone.forward( + attention_mask=att_2d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[current_embs, None], + use_cache=False, + fill_kv_cache=True, + ) + if vlm_out is None: + break + last = vlm_out[:, -1:].to(lm_head.weight.dtype) + logits_step = lm_head(last)[:, -1] # (B, V) + if special_ids and len(generated) < min_new_tokens: + for sid in special_ids: + logits_step[..., sid] = float("-inf") + next_ids = self._sample_next_token(logits_step, temperature, top_p) + tok_id = int(next_ids[0].item()) + generated.append(tok_id) + if eos_token_id is not None and tok_id == eos_token_id: + break + + new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0)) + new_emb = new_emb * text_emb_scale + current_embs = torch.cat([current_embs, new_emb], dim=1) + current_pad = torch.cat([current_pad, ones_step], dim=1) + current_att = torch.cat([current_att, ones_step], dim=1) + + decoded = tokenizer.decode(generated, skip_special_tokens=True).strip() + if not decoded and generated: + try: + self._last_select_message_debug = ( + f"raw_ids={generated[:16]} " + f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}" + ) + except Exception: # noqa: BLE001 + self._last_select_message_debug = f"raw_ids={generated[:16]}" + else: + self._last_select_message_debug = "" + return decoded + + @staticmethod + def _sample_next_token(logits: Tensor, temperature: float, top_p: float) -> Tensor: + if temperature <= 0.0: + return logits.argmax(dim=-1) + scaled = logits / max(temperature, 1e-6) + probs = torch.softmax(scaled, dim=-1) + if top_p < 1.0: + sorted_p, sorted_ix = torch.sort(probs, descending=True, dim=-1) + cum = torch.cumsum(sorted_p, dim=-1) + mask = cum > top_p + mask[..., 0] = False + sorted_p = sorted_p.masked_fill(mask, 0.0) + sorted_p = sorted_p / sorted_p.sum(dim=-1, keepdim=True).clamp_min(1e-8) + choice = torch.multinomial(sorted_p, num_samples=1) + return sorted_ix.gather(-1, choice).squeeze(-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1) diff --git a/src/lerobot/policies/pi052/processor_pi052.py b/src/lerobot/policies/pi052/processor_pi052.py new 
file mode 100644 index 000000000..6abe1cdcd --- /dev/null +++ b/src/lerobot/policies/pi052/processor_pi052.py @@ -0,0 +1,148 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""π0.5 v2 pre/post-processor factory. + +When ``config.recipe_path`` is set, the pre-processor pipeline becomes: + + rename observations + add batch dim + relative-action prep (inherited from π0.5) + NormalizerProcessorStep + RenderMessagesStep — recipe → messages, target_message_indices, + message_streams (PR 1 of the steerable + stack) + PI052TextTokenizerStep — messages → input_ids + label mask + + predict_actions + DeviceProcessorStep + +When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline +so unannotated datasets keep working. + +Post-processor is unchanged from π0.5. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import torch + +from lerobot.configs.recipe import TrainingRecipe +from lerobot.processor import ( + AbsoluteActionsProcessorStep, + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RelativeActionsProcessorStep, + RenameObservationsProcessorStep, + RenderMessagesStep, + UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, +) +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + +from ..pi05.processor_pi05 import make_pi05_pre_post_processors +from .configuration_pi052 import PI052Config +from .text_processor_pi052 import PI052TextTokenizerStep + + +def make_pi052_pre_post_processors( + config: PI052Config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Build PI0.5-v2's pre/post-processor pipelines. + + Falls through to π0.5's stock pipeline when ``recipe_path`` is unset. 
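+
+    Minimal usage sketch (``cfg`` and ``stats`` are placeholders for the
+    values the training script already holds)::
+
+        pre, post = make_pi052_pre_post_processors(cfg, dataset_stats=stats)
+        model_batch = pre(raw_batch)       # render recipe + tokenise + device
+        robot_action = post(model_action)  # unnormalise back to robot frame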
+ """ + if not config.recipe_path: + return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats) + + recipe = _load_recipe(config.recipe_path) + + relative_step = RelativeActionsProcessorStep( + enabled=config.use_relative_actions, + exclude_joints=getattr(config, "relative_exclude_joints", []), + action_names=getattr(config, "action_feature_names", None), + ) + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + relative_step, + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + RenderMessagesStep(recipe=recipe), + PI052TextTokenizerStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0), + memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0), + subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0), + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + AbsoluteActionsProcessorStep( + enabled=config.use_relative_actions, + relative_step=relative_step, + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + +def _load_recipe(path_str: str) -> TrainingRecipe: + """Resolve ``path_str`` to a ``TrainingRecipe``. + + Accepts an absolute path or a path relative to + ``src/lerobot/configs/`` (same lookup rules as + ``make_smolvla2_pre_post_processors``). + """ + p = Path(path_str) + if not p.is_absolute() and not p.exists(): + from lerobot.configs import recipe as _recipe_module # noqa: PLC0415 + + configs_dir = Path(_recipe_module.__file__).resolve().parent + candidate = configs_dir / path_str + if candidate.exists(): + p = candidate + return TrainingRecipe.from_yaml(p) diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py new file mode 100644 index 000000000..7dfb5dfc3 --- /dev/null +++ b/src/lerobot/policies/pi052/text_processor_pi052.py @@ -0,0 +1,303 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""π0.5 v2 text-tokenisation step. + +PaliGemma is *not* chat-pretrained, so unlike SmolVLA2 we can't lean on +``tokenizer.apply_chat_template``. Instead we concatenate the rendered +messages as plain text with simple ``User: ... Assistant: ...`` role +delimiters — matching the prompt format π0.5 uses in the paper +(``Task: ... State: ... Action: ...``). 
+ +Outputs: + +* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` — the + concatenated prompt tokenised by the PaliGemma tokenizer (the same + one ``processor_pi05`` already uses). +* ``text_labels`` — same shape as token ids, ``-100`` everywhere except + positions belonging to messages whose index is in + ``target_message_indices``. ``modeling_pi052`` runs cross-entropy on + those positions via the PaliGemma ``lm_head``. +* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered + target messages has ``message_streams[i] == "low_level"``. Same + semantics as the SmolVLA2 step. +""" + +from __future__ import annotations + +import copy +import logging +from dataclasses import dataclass +from typing import Any + +import torch + +from lerobot.configs import PipelineFeatureType, PolicyFeature +from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry +from lerobot.types import EnvTransition, TransitionKey +from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS + +logger = logging.getLogger(__name__) + + +def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]: + """Normalise a message's content to a plain string. + + The recipe renderer can emit ``content`` as a string OR as a list + of HF-style multimodal blocks (``{type: text, text: ...}``, + ``{type: image, feature: ...}``). PaliGemma's text tokenizer can + only consume strings, so we flatten: drop image blocks (cameras + flow through ``observation.images.*`` separately) and join text + block texts. + """ + new = dict(message) + new.pop("stream", None) + new.pop("target", None) + content = new.get("content") + if content is None: + new["content"] = "" + elif isinstance(content, str): + pass + elif isinstance(content, list): + parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") == "text": + t = block.get("text", "") + if isinstance(t, str): + parts.append(t) + new["content"] = "\n".join(parts) + else: + new["content"] = str(content) + return new + + +def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]: + """Concatenate messages into the π0.5-style flat prompt. + + Returns: + prompt: the full text the tokenizer will consume. + msg_spans: list of ``(char_start, char_end)`` covering each + message's content within ``prompt``. The + target-mask builder uses this to find the + character ranges belonging to the supervised + messages. + """ + parts: list[str] = [] + spans: list[tuple[int, int]] = [] + cursor = 0 + for m in messages: + role = m.get("role", "user") + content = m.get("content", "") or "" + # Role tag + newline. The model has to learn to emit the same + # role tokens at generation time, which is fine for greedy + # decoding because the chat template is implicit in the + # supervised target span. + header = f"{role.capitalize()}: " + # span covers ONLY the content portion (so labels are computed + # over the supervised payload, not the role tag). + full = header + content + "\n" + start = cursor + len(header) + end = start + len(content) + parts.append(full) + spans.append((start, end)) + cursor += len(full) + return "".join(parts), spans + + +@dataclass +@ProcessorStepRegistry.register(name="pi052_text_tokenizer") +class PI052TextTokenizerStep(ProcessorStep): + """Render messages → token ids + label mask + predict_actions flag. + + π0.5 analogue of ``SmolVLA2ChatTokenizerStep``. No chat template; + concatenates messages as ``User: ... \\nAssistant: ...`` text. 
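+
+    Illustrative construction, mirroring what
+    ``make_pi052_pre_post_processors`` passes::
+
+        step = PI052TextTokenizerStep(
+            tokenizer_name="google/paligemma-3b-pt-224",
+            max_length=512,
+            plan_dropout_prob=0.30,
+        )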
+ """ + + tokenizer_name: str = "google/paligemma-3b-pt-224" + max_length: int = 200 + padding: str = "max_length" + padding_side: str = "right" + plan_dropout_prob: float = 0.0 + memory_dropout_prob: float = 0.0 + subtask_dropout_prob: float = 0.0 + interjection_dropout_prob: float = 0.0 + dropout_seed: int | None = None + + def __post_init__(self) -> None: + self._tokenizer: Any = None + + def _ensure_tokenizer(self) -> Any: + if self._tokenizer is not None: + return self._tokenizer + from transformers import AutoTokenizer # noqa: PLC0415 + + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + return self._tokenizer + + # ------------------------------------------------------------------ + # Pipeline step + # ------------------------------------------------------------------ + + def __call__(self, transition: EnvTransition) -> EnvTransition | None: + transition = transition.copy() + complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {} + messages = complementary.get("messages") or [] + target_indices = list(complementary.get("target_message_indices") or []) + message_streams = list(complementary.get("message_streams") or []) + + if not messages: + # No recipe was rendered — caller will fall back to the + # plain Pi0.5 prompt path. We pass the transition through + # unmodified. + return transition + + # Optional: drop non-target messages per the dropout config. + # Keeps the supervised-target indices stable by re-mapping + # after removal. + if ( + self.plan_dropout_prob + or self.memory_dropout_prob + or self.subtask_dropout_prob + or self.interjection_dropout_prob + ): + messages, target_indices = self._apply_prompt_dropout( + messages, + target_indices, + complementary, + ) + + messages = [_strip_blocks(m) for m in messages] + prompt, spans = _format_messages(messages) + + tokenizer = self._ensure_tokenizer() + encoded = tokenizer( + prompt, + max_length=self.max_length, + padding=self.padding, + truncation=True, + return_tensors="pt", + return_offsets_mapping=True, + padding_side=self.padding_side, + ) + + input_ids = encoded["input_ids"][0] + attention_mask = encoded["attention_mask"][0].bool() + offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end) + + # Build label mask: -100 everywhere except over supervised + # target message char ranges. 
+        labels = torch.full_like(input_ids, fill_value=-100)
+        for idx in target_indices:
+            if idx >= len(spans):
+                continue
+            char_start, char_end = spans[idx]
+            for token_pos in range(input_ids.shape[0]):
+                if not attention_mask[token_pos]:
+                    continue
+                tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
+                if tok_end <= char_start or tok_start >= char_end:
+                    continue
+                labels[token_pos] = input_ids[token_pos]
+
+        # Prompt dropout remaps ``target_indices`` to post-drop positions,
+        # but ``message_streams`` still indexes the *original* message
+        # list, so read the stream flags through the pre-dropout target
+        # indices kept in ``complementary``.
+        orig_target_indices = list(complementary.get("target_message_indices") or [])
+        predict_actions = torch.tensor(
+            bool(
+                any(
+                    message_streams[i] == "low_level"
+                    for i in orig_target_indices
+                    if i < len(message_streams)
+                )
+            ),
+            dtype=torch.bool,
+        )
+
+        obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
+        obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0)
+        obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0)
+        transition[TransitionKey.OBSERVATION] = obs
+
+        transition[TransitionKey.COMPLEMENTARY_DATA] = {
+            **complementary,
+            "text_labels": labels.unsqueeze(0),
+            "predict_actions": predict_actions.unsqueeze(0),
+        }
+        return transition
+
+    # ------------------------------------------------------------------
+    # Per-component prompt dropout (Pi0.7 §V.E)
+    # ------------------------------------------------------------------
+
+    def _apply_prompt_dropout(
+        self,
+        messages: list[dict[str, Any]],
+        target_indices: list[int],
+        complementary: dict[str, Any],
+    ) -> tuple[list[dict[str, Any]], list[int]]:
+        """Drop messages classified as plan/memory/subtask context.
+
+        Targets are *never* dropped (they're the supervised payload).
+        Re-maps target_indices to the new positions after drops.
+        """
+        import random  # noqa: PLC0415
+
+        seed = self.dropout_seed
+        if seed is None:
+            # Explicit ``is None`` checks: ``dataset_index`` and
+            # ``frame_index`` 0 are valid seeds and must not fall through.
+            seed_src = complementary.get("dataset_index")
+            if seed_src is None:
+                seed_src = complementary.get("frame_index")
+            try:
+                seed = int(seed_src)
+            except (TypeError, ValueError):
+                seed = 0
+        rng = random.Random(seed)
+
+        keep_indices: list[int] = []
+        for idx, msg in enumerate(messages):
+            if idx in target_indices:
+                keep_indices.append(idx)
+                continue
+            kind = _classify_for_dropout(msg)
+            prob = {
+                "plan": self.plan_dropout_prob,
+                "memory": self.memory_dropout_prob,
+                "subtask": self.subtask_dropout_prob,
+                "interjection": self.interjection_dropout_prob,
+            }.get(kind, 0.0)
+            if prob > 0.0 and rng.random() < prob:
+                continue
+            keep_indices.append(idx)
+
+        # Build remap and apply
+        new_messages = [messages[i] for i in keep_indices]
+        old_to_new = {old: new for new, old in enumerate(keep_indices)}
+        new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
+        return new_messages, new_targets
+
+    def transform_features(
+        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+        return features
+
+
+def _classify_for_dropout(message: dict[str, Any]) -> str | None:
+    """Heuristic content-prefix classifier — mirrors SmolVLA2's."""
+    content = message.get("content")
+    if isinstance(content, list):
+        text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
+        content = " ".join(text_parts)
+    elif content is None:
+        return None
+    elif not isinstance(content, str):
+        return None
+    s = content.strip()
+    if s.startswith("Plan:") or s.startswith("Previous plan"):
+        return "plan"
+    if s.startswith("Memory:") or s.startswith("Previous memory"):
+        return "memory"
+    if s.startswith("Current subtask") or s.startswith("Completed subtask"):
+        return "subtask"
+    return None
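+
+# Illustrative classifier behaviour (message contents hypothetical):
+#   _classify_for_dropout({"content": "Plan: open the drawer"})       -> "plan"
+#   _classify_for_dropout({"content": "Previous memory: cup moved"})  -> "memory"
+#   _classify_for_dropout({"content": "pick up the sponge"})          -> None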