mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
feat(pi052): π0.5 v2 — full reproduction of the π0.5 paper recipe
New ``lerobot.policies.pi052`` (parallel to ``smolvla2``) that adds
text-prediction + hierarchical-inference on top of the existing π0.5
implementation. Mirrors the paper's §IV.D dual-head training:
L = H(text) + α * ‖ω - a - f_θ_action(...)‖², α = 10
Components:
* ``configuration_pi052.py`` thin PI05Config subclass; adds
recipe_path, text/flow loss weights
(default α=10 per paper), prompt
dropout knobs, ``unfreeze_lm_head``.
* ``text_processor_pi052.py`` PI052TextTokenizerStep — concatenates
rendered messages as ``Role: ...``
plain text (PaliGemma has no chat
template), tokenises with the
PaliGemma tokenizer, builds a label
mask covering supervised target
spans. Includes Pi 0.7 §V.E
per-component prompt dropout.
* ``processor_pi052.py`` make_pi052_pre_post_processors —
Rename + Batch + Relative +
Normalize + RenderMessagesStep +
PI052TextTokenizerStep + Device.
Falls back to π0.5's plain pipeline
when recipe_path is unset.
* ``modeling_pi052.py`` PI052Policy(PI05Policy) — re-enables
PaliGemma ``lm_head``, computes
text_loss via CE on the supervised
span, sums with flow_loss in
forward(), and adds select_message
for AR text generation at inference
(same surface as
SmolVLA2Policy.select_message so
SmolVLA2Runtime drives it unchanged).
Plus the supporting plumbing:
* recipe ``configs/recipes/pi052_hirobot.yaml`` — same Hi-Robot blend
as smolvla2_hirobot.yaml, with the same ``${subtask}`` /
``if_present`` supervision fix (current span at every frame, not
``${next_subtask}``).
* SLURM ``examples/training/pi052_hirobot.slurm`` — full training
command matching the SmolVLA2 launcher.
* factory registration: ``--policy.type=pi052`` resolves to
PI052Policy with the new processor.
Same multi-rate runtime (``lerobot.policies.smolvla2.inference``)
drives this policy too — both expose ``predict_action_chunk`` for the
action expert and ``select_message`` for the LM head.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=pi052-hirobot
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=48:00:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=8
|
||||
|
||||
# π0.5 v2 training — reproduces the π0.5 paper's hierarchical recipe.
|
||||
#
|
||||
# Same recipe blend as the SmolVLA2 stack (recipes/pi052_hirobot.yaml),
|
||||
# just on the PaliGemma 2B + Gemma-300m action-expert backbone the
|
||||
# paper uses. The text head learns subtask prediction via cross-
|
||||
# entropy on supervised spans; the action expert learns the flow
|
||||
# field. Paper §IV.D mixes the two losses with α=10, which we encode
|
||||
# as flow_loss_weight=10 / text_loss_weight=1.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
|
||||
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
|
||||
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
|
||||
|
||||
DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
|
||||
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/pi052_hirobot_super_poulain}"
|
||||
JOB_NAME="${JOB_NAME:-pi052-hirobot-super-poulain}"
|
||||
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
||||
BATCH_SIZE="${BATCH_SIZE:-32}"
|
||||
STEPS="${STEPS:-15000}"
|
||||
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/pi052_hirobot_${STEPS}_${RUN_ID}}"
|
||||
|
||||
echo "Training pi052 on $DATASET"
|
||||
echo " GPUs: $NUM_PROCESSES"
|
||||
echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
|
||||
echo " steps: $STEPS"
|
||||
echo " output: $OUTPUT_DIR"
|
||||
echo " loss mix: flow_loss_weight=10 (paper α), text_loss_weight=1"
|
||||
echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}"
|
||||
|
||||
accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--policy.type=pi052 \
|
||||
--policy.recipe_path=recipes/pi052_hirobot.yaml \
|
||||
--dataset.repo_id="$DATASET" \
|
||||
--dataset.revision=main \
|
||||
--dataset.video_backend=pyav \
|
||||
--output_dir="$OUTPUT_DIR" \
|
||||
--job_name="$JOB_NAME" \
|
||||
--policy.repo_id="$POLICY_REPO_ID" \
|
||||
--policy.compile_model=false \
|
||||
--policy.device=cuda \
|
||||
--policy.tokenizer_max_length=512 \
|
||||
--policy.text_loss_weight=1.0 \
|
||||
--policy.flow_loss_weight=10.0 \
|
||||
--policy.unfreeze_lm_head=true \
|
||||
--steps="$STEPS" \
|
||||
--policy.scheduler_decay_steps="$STEPS" \
|
||||
--batch_size="$BATCH_SIZE" \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--wandb.project=hirobot \
|
||||
--log_freq=100 \
|
||||
--save_freq="$STEPS" \
|
||||
--num_workers=0 \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.image_transforms.max_num_transforms=3 \
|
||||
--dataset.image_transforms.random_order=true \
|
||||
--policy.plan_dropout_prob=0.30 \
|
||||
--policy.memory_dropout_prob=0.30 \
|
||||
--policy.subtask_dropout_prob=0.20
|
||||
@@ -0,0 +1,92 @@
|
||||
# π0.5 v2 — Hi-Robot / MEM / ECoT blend, reproducing the paper's
|
||||
# hierarchical inference recipe on lerobot.
|
||||
#
|
||||
# Architecturally identical blend to ``smolvla2_hirobot.yaml`` — same
|
||||
# five sub-recipes (memory_update, user_interjection_response,
|
||||
# high_level_subtask, low_level_execution, ask_vqa_*) with the same
|
||||
# message layouts. The only difference is which backbone the renderer's
|
||||
# output is fed into:
|
||||
#
|
||||
# * SmolVLA2 calls SmolVLM's chat-template tokenizer
|
||||
# (``apply_chat_template`` with chat-pretrained role markers).
|
||||
# * π0.5 v2 concatenates the rendered messages as ``Role: content``
|
||||
# plain text, since PaliGemma is not chat-pretrained. See
|
||||
# ``PI052TextTokenizerStep`` in ``policies/pi052/text_processor_pi052.py``.
|
||||
#
|
||||
# Same supervision target convention as ``smolvla2_hirobot.yaml``: the
|
||||
# ``high_level_subtask`` recipe targets ``${subtask}`` (the *current*
|
||||
# active span at every frame) rather than ``${next_subtask}`` (which
|
||||
# is empty on stable phases and used to train the model to emit
|
||||
# newlines).
|
||||
|
||||
blend:
|
||||
|
||||
memory_update:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "emitted_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.16
|
||||
bindings:
|
||||
prior_plan: "nth_prev(style=plan, offset=1)"
|
||||
current_plan: "emitted_at(t, style=plan)"
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
|
||||
|
||||
# Pi 0.5 / Pi 0.7 supervision: predict the *current* active subtask
|
||||
# at every frame from task + plan + memory + visual prefix.
|
||||
# ``if_present: subtask`` skips frames with no active span instead of
|
||||
# supervising an empty target (the failure mode that produces newline
|
||||
# collapse).
|
||||
high_level_subtask:
|
||||
weight: 0.15
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.35
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
|
||||
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -128,6 +128,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi052":
|
||||
from .pi052.modeling_pi052 import PI052Policy
|
||||
|
||||
return PI052Policy
|
||||
elif name == "sac":
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -200,6 +204,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi052":
|
||||
from .pi052.configuration_pi052 import PI052Config
|
||||
|
||||
return PI052Config(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -370,6 +378,17 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "pi052":
|
||||
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
|
||||
# come before the PI05Config isinstance check below (otherwise
|
||||
# pi052 would silently pick up π0.5's processor).
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
processors = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from .pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
|
||||
inference recipe on lerobot.
|
||||
|
||||
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
|
||||
|
||||
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
|
||||
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
|
||||
(the "high-level subtask prediction" of the paper, §IV.D),
|
||||
* AR text generation at inference (:meth:`PI052Policy.select_message`),
|
||||
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
|
||||
text head against missing context at inference.
|
||||
|
||||
See ``src/lerobot/configs/recipes/pi052_hirobot.yaml`` for the
|
||||
canonical training recipe and
|
||||
``examples/training/pi052_hirobot.slurm`` for the launcher.
|
||||
"""
|
||||
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .modeling_pi052 import PI052Policy
|
||||
from .processor_pi052 import make_pi052_pre_post_processors
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
__all__ = [
|
||||
"PI052Config",
|
||||
"PI052Policy",
|
||||
"PI052TextTokenizerStep",
|
||||
"make_pi052_pre_post_processors",
|
||||
]
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
|
||||
hierarchical inference recipe.
|
||||
|
||||
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
|
||||
~300M Gemma action expert, joint training with FAST tokens during
|
||||
pre-train and flow matching during post-train), but with the
|
||||
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
|
||||
to predict both:
|
||||
|
||||
* **subtask strings** at the high level (cross-entropy on the LM
|
||||
head), and
|
||||
* **action chunks** at the low level (flow matching on the
|
||||
action-expert tokens).
|
||||
|
||||
This is the dual-head co-training pattern from the paper:
|
||||
|
||||
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, ℓ)‖²
|
||||
|
||||
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
|
||||
inference into a text-prediction step followed by an action-prediction
|
||||
step, which mirrors what ``SmolVLA2Runtime`` already does on a
|
||||
SmolVLM2 backbone.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
|
||||
from ..pi05.configuration_pi05 import PI05Config
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi052")
|
||||
@dataclass
|
||||
class PI052Config(PI05Config):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
|
||||
|
||||
See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head
|
||||
config. Same recipe-driven training surface; the only differences
|
||||
are which backbone the policy uses (PaliGemma here vs SmolVLM2
|
||||
there) and the default loss-weight scale (paper §IV.D uses
|
||||
``α=10`` between the two heads, which we encode as
|
||||
``flow_loss_weight=10, text_loss_weight=1``).
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/pi052_hirobot.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
|
||||
shipped alongside this policy. Set to ``None`` to disable recipe
|
||||
rendering and fall back to π0.5's single-task ``Task: ... Action:``
|
||||
prompt path (unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = False
|
||||
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
|
||||
chat template. So unlike SmolVLA2 we don't apply one. The recipe
|
||||
renderer's output is concatenated as a plain prefix + assistant
|
||||
suffix instead, mirroring how the π0.5 paper's high-level inference
|
||||
samples text auto-regressively after the prefix."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
# Paper §IV.D: total = H(text) + α * MSE(flow), α = 10. We split
|
||||
# the same total into two configurable knobs so individual scaling
|
||||
# is recoverable.
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / π0.5 behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 10.0
|
||||
"""Weight on the action-expert flow-matching term. Default ``10.0``
|
||||
matches the paper's α."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
|
||||
The existing ``PI05Policy`` zeroes / freezes the head on load
|
||||
because it never reads from it. Must be ``True`` for π0.5-style
|
||||
hierarchical inference."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# Same regulariser surface as SmolVLA2: randomly drop non-target
|
||||
# context messages so the LM head learns to handle missing /
|
||||
# stale plan / memory at inference. Defaults to 0.0 so behaviour
|
||||
# is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
# we're training it. Override the π0.5 default
|
||||
# (``train_expert_only=True``) unless the user explicitly opts
|
||||
# out of text training via ``text_loss_weight=0``.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
self.train_expert_only = False
|
||||
@@ -0,0 +1,339 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 policy — dual-head training & hierarchical inference.
|
||||
|
||||
A thin subclass of :class:`PI05Policy` that:
|
||||
|
||||
* keeps the PaliGemma ``lm_head`` unfrozen during fine-tuning
|
||||
(``PI05Policy`` zeroes / freezes it because it never reads from
|
||||
the head; ``PI052Config.unfreeze_lm_head`` flips that),
|
||||
* adds a ``text_loss`` term computed via cross-entropy on
|
||||
``text_labels`` (built by ``PI052TextTokenizerStep``),
|
||||
* adds :meth:`select_message` for AR text generation at inference
|
||||
(the high-level step in the π0.5 paper's two-stage inference loop),
|
||||
* combines both losses in :meth:`forward` per Eq. (1) of the paper:
|
||||
|
||||
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(...)‖²
|
||||
|
||||
with α controllable via ``config.flow_loss_weight``.
|
||||
|
||||
The same multi-rate runtime that drives ``SmolVLA2Runtime`` (see
|
||||
``lerobot.policies.smolvla2.inference``) can drive this policy too —
|
||||
both expose ``predict_action_chunk`` for the action expert and
|
||||
``select_message`` for the LM head.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
from ..pi05.configuration_pi05 import PI05Config
|
||||
from ..pi05.modeling_pi05 import PI05Policy
|
||||
from .configuration_pi052 import PI052Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PI052Policy(PI05Policy):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled."""
|
||||
|
||||
config_class = PI052Config
|
||||
name = "pi052"
|
||||
|
||||
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
|
||||
super().__init__(config, **kwargs)
|
||||
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
|
||||
# freezes a few terminal layers when ``train_expert_only`` is
|
||||
# the (default) True. We re-enable the head if the user
|
||||
# wants text supervision.
|
||||
if config.text_loss_weight > 0 and config.unfreeze_lm_head:
|
||||
self._unfreeze_lm_head()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Head unfreeze helper
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _unfreeze_lm_head(self) -> None:
|
||||
"""Walk the PaliGemma submodules and re-enable gradients on
|
||||
``lm_head`` + the immediately preceding norm / last text-model
|
||||
layer that ``PI05Policy`` typically freezes."""
|
||||
backbone = self.model.paligemma_with_expert.paligemma
|
||||
if hasattr(backbone, "lm_head"):
|
||||
for p in backbone.lm_head.parameters():
|
||||
p.requires_grad_(True)
|
||||
# The text model's final norm and last transformer block —
|
||||
# mirror SmolVLA2's logic, which finds these dynamically by
|
||||
# the trainable=False parameters that point at the head's
|
||||
# neighbourhood.
|
||||
text_model = getattr(backbone, "model", None)
|
||||
text_model = getattr(text_model, "language_model", text_model)
|
||||
if text_model is None:
|
||||
return
|
||||
norm = getattr(text_model, "norm", None)
|
||||
if norm is not None:
|
||||
for p in norm.parameters():
|
||||
p.requires_grad_(True)
|
||||
layers = getattr(text_model, "layers", None)
|
||||
if isinstance(layers, (list, torch.nn.ModuleList)) and len(layers) > 0:
|
||||
for p in layers[-1].parameters():
|
||||
p.requires_grad_(True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward (dual loss: flow + text)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
reduction: str = "mean",
|
||||
) -> tuple[Tensor, dict]:
|
||||
"""Dual-head forward: flow-matching loss + text-CE loss.
|
||||
|
||||
When ``text_labels`` isn't present in the batch (e.g. the
|
||||
recipe wasn't applied), we delegate to ``PI05Policy.forward``
|
||||
unchanged. Otherwise we compute both losses and sum them with
|
||||
``flow_loss_weight`` / ``text_loss_weight``.
|
||||
"""
|
||||
text_labels = batch.get("text_labels")
|
||||
predict_actions_t = batch.get("predict_actions")
|
||||
|
||||
run_flow = (
|
||||
self.config.flow_loss_weight > 0
|
||||
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
|
||||
)
|
||||
run_text = self.config.text_loss_weight > 0 and text_labels is not None
|
||||
|
||||
loss_dict: dict[str, Any] = {}
|
||||
total: Tensor | None = None
|
||||
|
||||
if run_flow:
|
||||
flow_loss, flow_dict = super().forward(batch, reduction=reduction)
|
||||
for k, v in flow_dict.items():
|
||||
loss_dict[f"flow_{k}"] = v
|
||||
loss_dict["flow_loss"] = (
|
||||
flow_loss.item() if isinstance(flow_loss, Tensor) and flow_loss.dim() == 0 else float("nan")
|
||||
)
|
||||
total = self.config.flow_loss_weight * flow_loss
|
||||
|
||||
if run_text:
|
||||
text_loss = self._compute_text_loss(batch, text_labels)
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
total = (
|
||||
self.config.text_loss_weight * text_loss
|
||||
if total is None
|
||||
else total + self.config.text_loss_weight * text_loss
|
||||
)
|
||||
|
||||
if total is None:
|
||||
# Both flow and text disabled — make this an obvious bug
|
||||
# rather than a silent zero loss.
|
||||
raise RuntimeError(
|
||||
"PI052Policy.forward: both flow_loss_weight and "
|
||||
"text_loss_weight are 0 (or text_labels missing) — "
|
||||
"nothing to train."
|
||||
)
|
||||
|
||||
loss_dict["loss"] = float(total.detach().item()) if total.dim() == 0 else float("nan")
|
||||
if reduction == "none":
|
||||
return total.expand(batch[OBS_LANGUAGE_TOKENS].shape[0]), loss_dict
|
||||
return total, loss_dict
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text loss
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||
"""Cross-entropy on PaliGemma's LM head over the supervised span.
|
||||
|
||||
Re-uses the same prefix-embedding path the flow head does:
|
||||
embed images + state + language tokens, run a forward pass,
|
||||
slice out the per-token logits at the supervised positions,
|
||||
compute CE.
|
||||
"""
|
||||
from torch.nn import functional as F # noqa: PLC0415
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, tokens, masks
|
||||
)
|
||||
# PaliGemma's text path: forward the prefix through the
|
||||
# backbone *without* the action expert. We piggy-back on the
|
||||
# existing PaliGemmaWithExpertModel.forward — it accepts a
|
||||
# list of expert inputs and returns parallel outputs.
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
if vlm_out is None:
|
||||
raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.")
|
||||
|
||||
# Logits over the vocab via the PaliGemma lm_head.
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(vlm_out.to(lm_head.weight.dtype))
|
||||
|
||||
# Shift for next-token prediction: predict token[i+1] from
|
||||
# hidden[i]. Both ``logits`` and ``text_labels`` are over the
|
||||
# same sequence length, so shift logits[:-1] vs labels[1:].
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = text_labels[..., 1:].contiguous()
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
return loss
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# select_message — AR text generation at inference
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def select_message(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
*,
|
||||
max_new_tokens: int = 128,
|
||||
min_new_tokens: int = 0,
|
||||
eos_token_id: int | None = None,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
tokenizer: Any = None,
|
||||
) -> str:
|
||||
"""Generate text continuation from a multimodal prefix.
|
||||
|
||||
Mirrors ``SmolVLA2Policy.select_message`` so the same
|
||||
:class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime`
|
||||
can drive π0.5 v2 unchanged.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if tokenizer is None:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
if eos_token_id is None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
special_ids: set[int] = set()
|
||||
try:
|
||||
for sid in (tokenizer.all_special_ids or []):
|
||||
if sid is not None:
|
||||
special_ids.add(int(sid))
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
if eos_token_id is not None:
|
||||
special_ids.add(int(eos_token_id))
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, tokens, masks
|
||||
)
|
||||
|
||||
device = prefix_embs.device
|
||||
bsize = prefix_embs.shape[0]
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
text_emb_scale = math.sqrt(emb_dim)
|
||||
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
|
||||
current_embs = prefix_embs
|
||||
current_pad = prefix_pad_masks
|
||||
current_att = prefix_att_masks
|
||||
generated: list[int] = []
|
||||
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
backbone = self.model.paligemma_with_expert
|
||||
lm_head = backbone.paligemma.lm_head
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
att_2d = make_att_2d_masks(current_pad, current_att)
|
||||
position_ids = torch.cumsum(current_pad, dim=1) - 1
|
||||
(vlm_out, _), _ = backbone.forward(
|
||||
attention_mask=att_2d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[current_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
if vlm_out is None:
|
||||
break
|
||||
last = vlm_out[:, -1:].to(lm_head.weight.dtype)
|
||||
logits_step = lm_head(last)[:, -1] # (B, V)
|
||||
if special_ids and len(generated) < min_new_tokens:
|
||||
for sid in special_ids:
|
||||
logits_step[..., sid] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
tok_id = int(next_ids[0].item())
|
||||
generated.append(tok_id)
|
||||
if eos_token_id is not None and tok_id == eos_token_id:
|
||||
break
|
||||
|
||||
new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0))
|
||||
new_emb = new_emb * text_emb_scale
|
||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
|
||||
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||
if not decoded and generated:
|
||||
try:
|
||||
self._last_select_message_debug = (
|
||||
f"raw_ids={generated[:16]} "
|
||||
f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
self._last_select_message_debug = f"raw_ids={generated[:16]}"
|
||||
else:
|
||||
self._last_select_message_debug = ""
|
||||
return decoded
|
||||
|
||||
@staticmethod
|
||||
def _sample_next_token(logits: Tensor, temperature: float, top_p: float) -> Tensor:
|
||||
if temperature <= 0.0:
|
||||
return logits.argmax(dim=-1)
|
||||
scaled = logits / max(temperature, 1e-6)
|
||||
probs = torch.softmax(scaled, dim=-1)
|
||||
if top_p < 1.0:
|
||||
sorted_p, sorted_ix = torch.sort(probs, descending=True, dim=-1)
|
||||
cum = torch.cumsum(sorted_p, dim=-1)
|
||||
mask = cum > top_p
|
||||
mask[..., 0] = False
|
||||
sorted_p = sorted_p.masked_fill(mask, 0.0)
|
||||
sorted_p = sorted_p / sorted_p.sum(dim=-1, keepdim=True).clamp_min(1e-8)
|
||||
choice = torch.multinomial(sorted_p, num_samples=1)
|
||||
return sorted_ix.gather(-1, choice).squeeze(-1)
|
||||
return torch.multinomial(probs, num_samples=1).squeeze(-1)
|
||||
@@ -0,0 +1,148 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 pre/post-processor factory.
|
||||
|
||||
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||
|
||||
rename observations
|
||||
add batch dim
|
||||
relative-action prep (inherited from π0.5)
|
||||
NormalizerProcessorStep
|
||||
RenderMessagesStep — recipe → messages, target_message_indices,
|
||||
message_streams (PR 1 of the steerable
|
||||
stack)
|
||||
PI052TextTokenizerStep — messages → input_ids + label mask +
|
||||
predict_actions
|
||||
DeviceProcessorStep
|
||||
|
||||
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
|
||||
so unannotated datasets keep working.
|
||||
|
||||
Post-processor is unchanged from π0.5.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
RenderMessagesStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
|
||||
def make_pi052_pre_post_processors(
|
||||
config: PI052Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build PI0.5-v2's pre/post-processor pipelines.
|
||||
|
||||
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
|
||||
"""
|
||||
if not config.recipe_path:
|
||||
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||
|
||||
recipe = _load_recipe(config.recipe_path)
|
||||
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
exclude_joints=getattr(config, "relative_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
relative_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
RenderMessagesStep(recipe=recipe),
|
||||
PI052TextTokenizerStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
|
||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AbsoluteActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
relative_step=relative_step,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/`` (same lookup rules as
|
||||
``make_smolvla2_pre_post_processors``).
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||
|
||||
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||
candidate = configs_dir / path_str
|
||||
if candidate.exists():
|
||||
p = candidate
|
||||
return TrainingRecipe.from_yaml(p)
|
||||
@@ -0,0 +1,303 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 text-tokenisation step.
|
||||
|
||||
PaliGemma is *not* chat-pretrained, so unlike SmolVLA2 we can't lean on
|
||||
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
|
||||
messages as plain text with simple ``User: ... Assistant: ...`` role
|
||||
delimiters — matching the prompt format π0.5 uses in the paper
|
||||
(``Task: ... State: ... Action: ...``).
|
||||
|
||||
Outputs:
|
||||
|
||||
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` — the
|
||||
concatenated prompt tokenised by the PaliGemma tokenizer (the same
|
||||
one ``processor_pi05`` already uses).
|
||||
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||
positions belonging to messages whose index is in
|
||||
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
|
||||
those positions via the PaliGemma ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``. Same
|
||||
semantics as the SmolVLA2 step.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalise a message's content to a plain string.
|
||||
|
||||
The recipe renderer can emit ``content`` as a string OR as a list
|
||||
of HF-style multimodal blocks (``{type: text, text: ...}``,
|
||||
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
|
||||
only consume strings, so we flatten: drop image blocks (cameras
|
||||
flow through ``observation.images.*`` separately) and join text
|
||||
block texts.
|
||||
"""
|
||||
new = dict(message)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
content = new.get("content")
|
||||
if content is None:
|
||||
new["content"] = ""
|
||||
elif isinstance(content, str):
|
||||
pass
|
||||
elif isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
t = block.get("text", "")
|
||||
if isinstance(t, str):
|
||||
parts.append(t)
|
||||
new["content"] = "\n".join(parts)
|
||||
else:
|
||||
new["content"] = str(content)
|
||||
return new
|
||||
|
||||
|
||||
def _format_messages(messages: list[dict[str, Any]]) -> tuple[str, list[tuple[int, int]]]:
|
||||
"""Concatenate messages into the π0.5-style flat prompt.
|
||||
|
||||
Returns:
|
||||
prompt: the full text the tokenizer will consume.
|
||||
msg_spans: list of ``(char_start, char_end)`` covering each
|
||||
message's content within ``prompt``. The
|
||||
target-mask builder uses this to find the
|
||||
character ranges belonging to the supervised
|
||||
messages.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
spans: list[tuple[int, int]] = []
|
||||
cursor = 0
|
||||
for m in messages:
|
||||
role = m.get("role", "user")
|
||||
content = m.get("content", "") or ""
|
||||
# Role tag + newline. The model has to learn to emit the same
|
||||
# role tokens at generation time, which is fine for greedy
|
||||
# decoding because the chat template is implicit in the
|
||||
# supervised target span.
|
||||
header = f"{role.capitalize()}: "
|
||||
# span covers ONLY the content portion (so labels are computed
|
||||
# over the supervised payload, not the role tag).
|
||||
full = header + content + "\n"
|
||||
start = cursor + len(header)
|
||||
end = start + len(content)
|
||||
parts.append(full)
|
||||
spans.append((start, end))
|
||||
cursor += len(full)
|
||||
return "".join(parts), spans
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
|
||||
class PI052TextTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
π0.5 analogue of ``SmolVLA2ChatTokenizerStep``. No chat template;
|
||||
concatenates messages as ``User: ... \\nAssistant: ...`` text.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
max_length: int = 200
|
||||
padding: str = "max_length"
|
||||
padding_side: str = "right"
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
interjection_dropout_prob: float = 0.0
|
||||
dropout_seed: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._tokenizer: Any = None
|
||||
|
||||
def _ensure_tokenizer(self) -> Any:
|
||||
if self._tokenizer is not None:
|
||||
return self._tokenizer
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
return self._tokenizer
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pipeline step
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
transition = transition.copy()
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
messages = complementary.get("messages") or []
|
||||
target_indices = list(complementary.get("target_message_indices") or [])
|
||||
message_streams = list(complementary.get("message_streams") or [])
|
||||
|
||||
if not messages:
|
||||
# No recipe was rendered — caller will fall back to the
|
||||
# plain Pi0.5 prompt path. We pass the transition through
|
||||
# unmodified.
|
||||
return transition
|
||||
|
||||
# Optional: drop non-target messages per the dropout config.
|
||||
# Keeps the supervised-target indices stable by re-mapping
|
||||
# after removal.
|
||||
if (
|
||||
self.plan_dropout_prob
|
||||
or self.memory_dropout_prob
|
||||
or self.subtask_dropout_prob
|
||||
or self.interjection_dropout_prob
|
||||
):
|
||||
messages, target_indices = self._apply_prompt_dropout(
|
||||
messages,
|
||||
target_indices,
|
||||
complementary,
|
||||
)
|
||||
|
||||
messages = [_strip_blocks(m) for m in messages]
|
||||
prompt, spans = _format_messages(messages)
|
||||
|
||||
tokenizer = self._ensure_tokenizer()
|
||||
encoded = tokenizer(
|
||||
prompt,
|
||||
max_length=self.max_length,
|
||||
padding=self.padding,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_offsets_mapping=True,
|
||||
padding_side=self.padding_side,
|
||||
)
|
||||
|
||||
input_ids = encoded["input_ids"][0]
|
||||
attention_mask = encoded["attention_mask"][0].bool()
|
||||
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
|
||||
|
||||
# Build label mask: -100 everywhere except over supervised
|
||||
# target message char ranges.
|
||||
labels = torch.full_like(input_ids, fill_value=-100)
|
||||
for idx in target_indices:
|
||||
if idx >= len(spans):
|
||||
continue
|
||||
char_start, char_end = spans[idx]
|
||||
for token_pos in range(input_ids.shape[0]):
|
||||
if not attention_mask[token_pos]:
|
||||
continue
|
||||
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
|
||||
if tok_end <= char_start or tok_start >= char_end:
|
||||
continue
|
||||
labels[token_pos] = input_ids[token_pos]
|
||||
|
||||
predict_actions = torch.tensor(
|
||||
bool(any(message_streams[i] == "low_level" for i in target_indices if i < len(message_streams))),
|
||||
dtype=torch.bool,
|
||||
)
|
||||
|
||||
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0)
|
||||
obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0)
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = {
|
||||
**complementary,
|
||||
"text_labels": labels.unsqueeze(0),
|
||||
"predict_actions": predict_actions.unsqueeze(0),
|
||||
}
|
||||
return transition
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-component prompt dropout (Pi0.7 §V.E)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _apply_prompt_dropout(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||
"""Drop messages classified as plan/memory/subtask context.
|
||||
|
||||
Targets are *never* dropped (they're the supervised payload).
|
||||
Re-maps target_indices to the new positions after drops.
|
||||
"""
|
||||
import random # noqa: PLC0415
|
||||
|
||||
seed = self.dropout_seed
|
||||
if seed is None:
|
||||
seed_src = complementary.get("dataset_index") or complementary.get("frame_index") or 0
|
||||
try:
|
||||
seed = int(seed_src)
|
||||
except (TypeError, ValueError):
|
||||
seed = 0
|
||||
rng = random.Random(seed)
|
||||
|
||||
keep_indices: list[int] = []
|
||||
for idx, msg in enumerate(messages):
|
||||
if idx in target_indices:
|
||||
keep_indices.append(idx)
|
||||
continue
|
||||
kind = _classify_for_dropout(msg)
|
||||
prob = {
|
||||
"plan": self.plan_dropout_prob,
|
||||
"memory": self.memory_dropout_prob,
|
||||
"subtask": self.subtask_dropout_prob,
|
||||
"interjection": self.interjection_dropout_prob,
|
||||
}.get(kind, 0.0)
|
||||
if prob > 0.0 and rng.random() < prob:
|
||||
continue
|
||||
keep_indices.append(idx)
|
||||
|
||||
# Build remap and apply
|
||||
new_messages = [messages[i] for i in keep_indices]
|
||||
old_to_new = {old: new for new, old in enumerate(keep_indices)}
|
||||
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
|
||||
return new_messages, new_targets
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Heuristic content-prefix classifier — mirrors SmolVLA2's."""
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
content = " ".join(text_parts)
|
||||
elif content is None:
|
||||
return None
|
||||
elif not isinstance(content, str):
|
||||
return None
|
||||
s = content.strip()
|
||||
if s.startswith("Plan:") or s.startswith("Previous plan"):
|
||||
return "plan"
|
||||
if s.startswith("Memory:") or s.startswith("Previous memory"):
|
||||
return "memory"
|
||||
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
|
||||
return "subtask"
|
||||
return None
|
||||
Reference in New Issue
Block a user