feat(pi052): condition low-level prompt on state + fix eval slowdown

- Inject discretized proprioceptive state (256 bins, pi05 format) into
  low-level action-conditioning prompts in both training
  (PI052TextTokenizerStep) and eval (_with_low_level_subtask_prompt),
  matching the recipe's documented "[images, subtask, state]" intent.
  Higher-level subtask/memory text streams stay state-free.
- Cache the loc-token tokenizer (_get_loc_tokenizer) instead of reloading
  it from disk on every _build_text_batch/select_message call (it ran
  twice per env per replan and dominated eval runtime).
- Add a KV cache to select_message decode (bit-identical output to the
  recompute path) to avoid O(n^2) generation.

Net: pi052 eval ~2.9 s/it -> ~0.1 s/it (~25x).
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-14 13:57:30 +02:00
parent c5965d4971
commit 3427499212
3 changed files with 122 additions and 21 deletions
+20 -1
View File
@@ -248,6 +248,22 @@ class DispatchAction(InferenceStep):
# ---------------------------------------------------------------------------
# Module-level memo of loc-token-registered tokenizers, keyed by tokenizer name.
_LOC_TOKENIZER_CACHE: dict[str, Any] = {}


def _get_loc_tokenizer(tok_name: str, auto_tokenizer_cls: Any, register_loc_fn: Any) -> Any:
    """Fetch the loc-token-registered tokenizer for ``tok_name``, reusing a cached one.

    On a cache miss the tokenizer is built via
    ``register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name))`` and stored
    in ``_LOC_TOKENIZER_CACHE`` so later calls with the same name skip the load.

    Args:
        tok_name: Name/path handed to ``from_pretrained``; also the cache key.
        auto_tokenizer_cls: Object exposing ``from_pretrained(name)``.
        register_loc_fn: Callable that takes the loaded tokenizer and returns
            the tokenizer to use (with loc tokens registered).

    Returns:
        The cached tokenizer for ``tok_name``, or the freshly built one.
    """
    cached = _LOC_TOKENIZER_CACHE.get(tok_name)
    if cached is not None:
        return cached

    # Cache miss: load once, register the loc tokens, and remember the result.
    built = register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name))
    _LOC_TOKENIZER_CACHE[tok_name] = built
    return built
def _build_text_batch(
policy: Any,
prompt_messages: list[dict[str, Any]],
@@ -277,7 +293,10 @@ def _build_text_batch(
)
# Register PaliGemma's <locDDDD> tokens so inference encoding /
# decoding sees them as single vocab ids — must match training.
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
# The tokenizer is read-only after registration, so cache it: rebuilding it
# from disk on every call dominated eval runtime (this runs twice per env
# per replan — subtask gen + action prompt).
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
prompt, _spans = _format_messages(messages)
+52 -13
View File
@@ -55,6 +55,7 @@ from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
OPENPI_ATTENTION_MASK_VALUE,
)
from lerobot.utils.import_utils import require_package
@@ -1592,6 +1593,7 @@ class PI052Policy(PreTrainedPolicy):
top_p: float = 1.0,
tokenizer: Any = None,
suppress_loc_tokens: bool = False,
use_kv_cache: bool = True,
) -> str:
"""Generate text continuation from a multimodal prefix.
@@ -1613,11 +1615,11 @@ class PI052Policy(PreTrainedPolicy):
if tokenizer is None:
from transformers import AutoTokenizer # noqa: PLC0415
from .inference.steps import _get_loc_tokenizer # noqa: PLC0415
from .text_processor_pi052 import register_paligemma_loc_tokens # noqa: PLC0415
tokenizer = register_paligemma_loc_tokens(
AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
)
tok_name = getattr(self.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
if eos_token_id is None:
eos_token_id = tokenizer.eos_token_id
@@ -1647,7 +1649,16 @@ class PI052Policy(PreTrainedPolicy):
current_pad = prefix_pad_masks
current_att = prefix_att_masks
generated: list[int] = []
new_emb = None
# KV-cache decode: encode the (image-heavy) prefix once, then feed only
# the newly sampled token each step, attending to the cached keys. This
# turns an O(n_tokens * prefix_len) recompute into O(prefix_len + n_tokens)
# and is the dominant cost here (the prefix carries ~3*256 image tokens).
# With ``use_kv_cache=False`` the loop reduces exactly to the original
# recompute path (cache stays ``None`` so every step re-runs the full
# prefix), which we keep as a fallback / parity reference.
cache = None
backbone = self.model.paligemma_with_expert
lm_head = backbone.paligemma.lm_head
@@ -1666,16 +1677,32 @@ class PI052Policy(PreTrainedPolicy):
)
for _ in range(max_new_tokens):
att_2d = make_att_2d_masks(current_pad, current_att)
position_ids = torch.cumsum(current_pad, dim=1) - 1
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
(vlm_out, _), _ = backbone.forward(
if cache is None:
# First step (and every step when caching is disabled): run the
# full bidirectional-prefix forward. ``current_*`` already grow
# in the no-cache fallback below.
step_embs = current_embs
att_2d = make_att_2d_masks(current_pad, current_att)
position_ids = torch.cumsum(current_pad, dim=1) - 1
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
else:
# Incremental step: only the last token. It attends to every
# valid cached key (``current_pad`` already includes this token),
# so pad positions in the prefix stay masked just like the
# recompute path.
step_embs = new_emb
att_2d = current_pad[:, None, :]
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
position_ids = (torch.cumsum(current_pad, dim=1) - 1)[:, -1:]
(vlm_out, _), new_cache = backbone.forward(
attention_mask=att_2d_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[current_embs, None],
use_cache=False,
past_key_values=cache,
inputs_embeds=[step_embs, None],
use_cache=use_kv_cache,
)
if use_kv_cache:
cache = new_cache
if vlm_out is None:
break
last = vlm_out[:, -1:].to(lm_head.weight.dtype)
@@ -1702,9 +1729,13 @@ class PI052Policy(PreTrainedPolicy):
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0).
new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0))
current_embs = torch.cat([current_embs, new_emb], dim=1)
# ``current_pad`` tracks valid keys for both paths (cache mask +
# position ids). Only the recompute path needs the full embedding /
# block-attention history re-fed each step.
current_pad = torch.cat([current_pad, ones_step], dim=1)
current_att = torch.cat([current_att, ones_step], dim=1)
if not use_kv_cache:
current_embs = torch.cat([current_embs, new_emb], dim=1)
current_att = torch.cat([current_att, ones_step], dim=1)
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
if not decoded and generated:
@@ -1744,10 +1775,15 @@ class PI052Policy(PreTrainedPolicy):
def _with_low_level_subtask_prompt(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
from .inference.steps import _build_text_batch # noqa: PLC0415
from .text_processor_pi052 import discretize_state_str # noqa: PLC0415
n = self._batch_size_from_observation(batch)
self._ensure_subtask_state(n)
tasks = self._tasks_from_batch(batch, n)
# Normalized state for the low-level action prompt (mirrors training:
# "User: {subtask}, State: {256-bin};"). batch state is already
# normalized by the eval preprocessor's NormalizerProcessorStep.
state_all = batch.get(OBS_STATE)
# Generate one subtask per parallel env, each conditioned on that env's
# own task + observation, then stack the per-env prompts into a single
@@ -1758,9 +1794,12 @@ class PI052Policy(PreTrainedPolicy):
for i in range(n):
obs_i = self._slice_observation(batch, i)
subtask = self._generate_low_level_subtask(obs_i, tasks[i], i)
content = subtask
if torch.is_tensor(state_all):
content = f"{subtask}, State: {discretize_state_str(state_all[i])};"
text_batch = _build_text_batch(
self,
[{"role": "user", "content": subtask}],
[{"role": "user", "content": content}],
add_generation_prompt=False,
)
rows.append((text_batch["lang_tokens"], text_batch["lang_masks"]))
@@ -40,17 +40,40 @@ import logging
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
logger = logging.getLogger(__name__)
def discretize_state_str(state_row: Any) -> str:
    """Discretize a single normalized state vector into 256 bins, space-joined.

    Intended to mirror pi05's ``Pi05PrepareStateTokenizerProcessorStep`` (same
    bins / convention) so pi052's low-level action prompt carries
    proprioception in the same textual format. Expects state already
    normalized by the upstream ``NormalizerProcessorStep``.

    Args:
        state_row: A tensor-like object (anything exposing ``detach``) or any
            value accepted by ``np.asarray``. Multi-dimensional input is
            flattened in row-major order.

    Returns:
        The bin index of every element, joined by single spaces. Values in
        ``[-1, 1)`` map to ``0..255``; values ``>= 1`` map to ``255`` and
        values ``< -1`` map to ``-1`` (no clipping is applied, keeping the
        binning identical to the expression it mirrors).
    """
    if hasattr(state_row, "detach"):
        # Upcast to float64 before leaving torch: ``Tensor.numpy()`` rejects
        # bfloat16, which would crash prompt building for bf16 states. The
        # widening is lossless, and the bin edges below are float64 anyway, so
        # results for float16/float32/float64 inputs are unchanged.
        arr = state_row.detach().cpu().double().numpy()
    else:
        arr = np.asarray(state_row)

    # 256 left bin edges evenly spaced over [-1, 1): edge k is -1 + k * 2/256.
    edges = np.linspace(-1, 1, 256 + 1)[:-1]
    # ``digitize`` returns 1-based bin positions; shift to 0-based indices.
    disc = np.digitize(arr, bins=edges) - 1
    return " ".join(str(int(x)) for x in disc.reshape(-1).tolist())
def _state_row_at(state_all: Any, pos: int) -> Any:
    """Pick the state belonging to sample ``pos`` out of ``state_all``.

    Args:
        state_all: ``None``, or a state object. Only objects exposing an
            ``ndim`` attribute of at least 2 are indexed.
        pos: Index of the sample within the leading dimension.

    Returns:
        ``state_all[pos]`` when ``state_all`` has ``ndim >= 2``; otherwise
        ``state_all`` itself, unchanged (including ``None``).
    """
    # Anything without an ``ndim`` attribute (or with fewer than two dims) is
    # treated as already being a single sample's state.
    has_batch_dim = state_all is not None and getattr(state_all, "ndim", 0) >= 2
    return state_all[pos] if has_batch_dim else state_all
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
if isinstance(content, str):
@@ -391,6 +414,10 @@ class PI052TextTokenizerStep(ProcessorStep):
return transition
tokenizer = self._ensure_tokenizer()
# Normalized proprioceptive state (set by NormalizerProcessorStep, which
# runs before this step). Injected into low-level action prompts so the
# action expert sees proprioception, matching pi05's discretized State:.
state_all = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE)
# VQA coords are 0–1000 normalized (Qwen2.5-VL convention) — the
# <loc> conversion is camera-resolution-independent and needs no
# observation lookup here.
@@ -404,13 +431,16 @@ class PI052TextTokenizerStep(ProcessorStep):
list(tgt_indices),
complementary,
sample_idx=int(s_idx) if s_idx is not None else None,
state_row=_state_row_at(state_all, pos),
)
for msg, streams, tgt_indices, s_idx in zip(
messages,
complementary.get("message_streams") or [[] for _ in messages],
complementary.get("target_message_indices") or [[] for _ in messages],
indices_iter,
strict=False,
for pos, (msg, streams, tgt_indices, s_idx) in enumerate(
zip(
messages,
complementary.get("message_streams") or [[] for _ in messages],
complementary.get("target_message_indices") or [[] for _ in messages],
indices_iter,
strict=False,
)
)
]
else:
@@ -423,6 +453,7 @@ class PI052TextTokenizerStep(ProcessorStep):
list(complementary.get("target_message_indices") or []),
complementary,
sample_idx=sample_idx,
state_row=_state_row_at(state_all, 0),
)
]
@@ -446,6 +477,7 @@ class PI052TextTokenizerStep(ProcessorStep):
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
state_row: Any = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
# Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping
@@ -472,6 +504,17 @@ class PI052TextTokenizerStep(ProcessorStep):
# stripping, so the spoken reply is actually tokenized and
# supervised (PaliGemma's flat prompt has no structured calls).
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
# Low-level (action-conditioning) samples get the discretized state
# appended to their user message, mirroring pi05's
# "..., State: {256-bin};" so the action expert sees proprioception.
# Higher-level text streams (subtask/memory generation) stay state-free.
if state_row is not None and any(s == "low_level" for s in message_streams):
state_str = discretize_state_str(state_row)
for m in reversed(messages):
if m.get("role") == "user":
base = _content_to_text(m.get("content", ""))
m["content"] = f"{base}, State: {state_str};"
break
# Append EOS to supervised target turns so the LM head learns to
# stop (the span covers it → it becomes a supervised label).
prompt, spans = _format_messages(