From 34274992120c5fbc640d418a17dbae67e5f00bf2 Mon Sep 17 00:00:00 2001 From: pepijn223 Date: Sun, 14 Jun 2026 13:57:30 +0200 Subject: [PATCH] feat(pi052): condition low-level prompt on state + fix eval slowdown - Inject discretized proprioceptive state (256 bins, pi05 format) into low-level action-conditioning prompts in both training (PI052TextTokenizerStep) and eval (_with_low_level_subtask_prompt), matching the recipe's documented "[images, subtask, state]" intent. Higher-level subtask/memory text streams stay state-free. - Cache the loc-token tokenizer (_get_loc_tokenizer) instead of reloading it from disk on every _build_text_batch/select_message call (it ran twice per env per replan and dominated eval runtime). - Add a KV cache to select_message decode (bit-identical output to the recompute path) to avoid O(n^2) generation. Net: pi052 eval ~2.9 s/it -> ~0.1 s/it (~25x). Co-authored-by: Cursor --- src/lerobot/policies/pi052/inference/steps.py | 21 +++++- src/lerobot/policies/pi052/modeling_pi052.py | 65 +++++++++++++++---- .../policies/pi052/text_processor_pi052.py | 57 ++++++++++++++-- 3 files changed, 122 insertions(+), 21 deletions(-) diff --git a/src/lerobot/policies/pi052/inference/steps.py b/src/lerobot/policies/pi052/inference/steps.py index d205cc6e7..819f73cd4 100644 --- a/src/lerobot/policies/pi052/inference/steps.py +++ b/src/lerobot/policies/pi052/inference/steps.py @@ -248,6 +248,22 @@ class DispatchAction(InferenceStep): # --------------------------------------------------------------------------- +_LOC_TOKENIZER_CACHE: dict[str, Any] = {} + + +def _get_loc_tokenizer(tok_name: str, auto_tokenizer_cls: Any, register_loc_fn: Any) -> Any: + """Return a loc-token-registered tokenizer, loading from disk only once. + + ``AutoTokenizer.from_pretrained`` + loc-token registration is expensive and + the result is immutable, so cache per ``tok_name``. + """ + tokenizer = _LOC_TOKENIZER_CACHE.get(tok_name) + if tokenizer is None: + tokenizer = register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name)) + _LOC_TOKENIZER_CACHE[tok_name] = tokenizer + return tokenizer + + def _build_text_batch( policy: Any, prompt_messages: list[dict[str, Any]], @@ -277,7 +293,10 @@ def _build_text_batch( ) # Register PaliGemma's tokens so inference encoding / # decoding sees them as single vocab ids — must match training. - tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name)) + # The tokenizer is read-only after registration, so cache it: rebuilding it + # from disk on every call dominated eval runtime (this runs twice per env + # per replan — subtask gen + action prompt). + tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens) messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages] prompt, _spans = _format_messages(messages) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 1283042d4..a87e9b329 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -55,6 +55,7 @@ from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, + OBS_STATE, OPENPI_ATTENTION_MASK_VALUE, ) from lerobot.utils.import_utils import require_package @@ -1592,6 +1593,7 @@ class PI052Policy(PreTrainedPolicy): top_p: float = 1.0, tokenizer: Any = None, suppress_loc_tokens: bool = False, + use_kv_cache: bool = True, ) -> str: """Generate text continuation from a multimodal prefix. @@ -1613,11 +1615,11 @@ class PI052Policy(PreTrainedPolicy): if tokenizer is None: from transformers import AutoTokenizer # noqa: PLC0415 + from .inference.steps import _get_loc_tokenizer # noqa: PLC0415 from .text_processor_pi052 import register_paligemma_loc_tokens # noqa: PLC0415 - tokenizer = register_paligemma_loc_tokens( - AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") - ) + tok_name = getattr(self.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224" + tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens) if eos_token_id is None: eos_token_id = tokenizer.eos_token_id @@ -1647,7 +1649,16 @@ class PI052Policy(PreTrainedPolicy): current_pad = prefix_pad_masks current_att = prefix_att_masks generated: list[int] = [] + new_emb = None + # KV-cache decode: encode the (image-heavy) prefix once, then feed only + # the newly sampled token each step, attending to the cached keys. This + # turns an O(n_tokens * prefix_len) recompute into O(prefix_len + n_tokens) + # and is the dominant cost here (the prefix carries ~3*256 image tokens). + # With ``use_kv_cache=False`` the loop reduces exactly to the original + # recompute path (cache stays ``None`` so every step re-runs the full + # prefix), which we keep as a fallback / parity reference. + cache = None backbone = self.model.paligemma_with_expert lm_head = backbone.paligemma.lm_head @@ -1666,16 +1677,32 @@ class PI052Policy(PreTrainedPolicy): ) for _ in range(max_new_tokens): - att_2d = make_att_2d_masks(current_pad, current_att) - position_ids = torch.cumsum(current_pad, dim=1) - 1 - att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype) - (vlm_out, _), _ = backbone.forward( + if cache is None: + # First step (and every step when caching is disabled): run the + # full bidirectional-prefix forward. ``current_*`` already grow + # in the no-cache fallback below. + step_embs = current_embs + att_2d = make_att_2d_masks(current_pad, current_att) + position_ids = torch.cumsum(current_pad, dim=1) - 1 + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype) + else: + # Incremental step: only the last token. It attends to every + # valid cached key (``current_pad`` already includes this token), + # so pad positions in the prefix stay masked just like the + # recompute path. + step_embs = new_emb + att_2d = current_pad[:, None, :] + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype) + position_ids = (torch.cumsum(current_pad, dim=1) - 1)[:, -1:] + (vlm_out, _), new_cache = backbone.forward( attention_mask=att_2d_4d, position_ids=position_ids, - past_key_values=None, - inputs_embeds=[current_embs, None], - use_cache=False, + past_key_values=cache, + inputs_embeds=[step_embs, None], + use_cache=use_kv_cache, ) + if use_kv_cache: + cache = new_cache if vlm_out is None: break last = vlm_out[:, -1:].to(lm_head.weight.dtype) @@ -1702,9 +1729,13 @@ class PI052Policy(PreTrainedPolicy): # embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0). new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0)) - current_embs = torch.cat([current_embs, new_emb], dim=1) + # ``current_pad`` tracks valid keys for both paths (cache mask + + # position ids). Only the recompute path needs the full embedding / + # block-attention history re-fed each step. current_pad = torch.cat([current_pad, ones_step], dim=1) - current_att = torch.cat([current_att, ones_step], dim=1) + if not use_kv_cache: + current_embs = torch.cat([current_embs, new_emb], dim=1) + current_att = torch.cat([current_att, ones_step], dim=1) decoded = tokenizer.decode(generated, skip_special_tokens=True).strip() if not decoded and generated: @@ -1744,10 +1775,15 @@ class PI052Policy(PreTrainedPolicy): def _with_low_level_subtask_prompt(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: from .inference.steps import _build_text_batch # noqa: PLC0415 + from .text_processor_pi052 import discretize_state_str # noqa: PLC0415 n = self._batch_size_from_observation(batch) self._ensure_subtask_state(n) tasks = self._tasks_from_batch(batch, n) + # Normalized state for the low-level action prompt (mirrors training: + # "User: {subtask}, State: {256-bin};"). batch state is already + # normalized by the eval preprocessor's NormalizerProcessorStep. + state_all = batch.get(OBS_STATE) # Generate one subtask per parallel env, each conditioned on that env's # own task + observation, then stack the per-env prompts into a single @@ -1758,9 +1794,12 @@ class PI052Policy(PreTrainedPolicy): for i in range(n): obs_i = self._slice_observation(batch, i) subtask = self._generate_low_level_subtask(obs_i, tasks[i], i) + content = subtask + if torch.is_tensor(state_all): + content = f"{subtask}, State: {discretize_state_str(state_all[i])};" text_batch = _build_text_batch( self, - [{"role": "user", "content": subtask}], + [{"role": "user", "content": content}], add_generation_prompt=False, ) rows.append((text_batch["lang_tokens"], text_batch["lang_masks"])) diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py index f121db929..09175c824 100644 --- a/src/lerobot/policies/pi052/text_processor_pi052.py +++ b/src/lerobot/policies/pi052/text_processor_pi052.py @@ -40,17 +40,40 @@ import logging from dataclasses import dataclass from typing import Any +import numpy as np import torch from torch import Tensor from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry from lerobot.types import EnvTransition, TransitionKey -from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE logger = logging.getLogger(__name__) +def discretize_state_str(state_row: Any) -> str: + """Discretize a single normalized state vector into 256 bins, space-joined. + + Mirrors pi05's ``Pi05PrepareStateTokenizerProcessorStep`` (same bins / + convention) so pi052's low-level action prompt carries proprioception in + the exact format pi05 was trained on. Expects state already normalized by + the upstream ``NormalizerProcessorStep``. + """ + arr = state_row.detach().cpu().numpy() if hasattr(state_row, "detach") else np.asarray(state_row) + disc = np.digitize(arr, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + return " ".join(str(int(x)) for x in disc.reshape(-1).tolist()) + + +def _state_row_at(state_all: Any, pos: int) -> Any: + """Select the per-sample state row from a (possibly batched) state tensor.""" + if state_all is None: + return None + if hasattr(state_all, "ndim") and state_all.ndim >= 2: + return state_all[pos] + return state_all + + def _content_to_text(content: Any) -> str: """Collapse a message's ``content`` (string or multimodal blocks) to text.""" if isinstance(content, str): @@ -391,6 +414,10 @@ class PI052TextTokenizerStep(ProcessorStep): return transition tokenizer = self._ensure_tokenizer() + # Normalized proprioceptive state (set by NormalizerProcessorStep, which + # runs before this step). Injected into low-level action prompts so the + # action expert sees proprioception, matching pi05's discretized State:. + state_all = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE) # VQA coords are 0–1000 normalized (Qwen2.5-VL convention) — the # conversion is camera-resolution-independent and needs no # observation lookup here. @@ -404,13 +431,16 @@ class PI052TextTokenizerStep(ProcessorStep): list(tgt_indices), complementary, sample_idx=int(s_idx) if s_idx is not None else None, + state_row=_state_row_at(state_all, pos), ) - for msg, streams, tgt_indices, s_idx in zip( - messages, - complementary.get("message_streams") or [[] for _ in messages], - complementary.get("target_message_indices") or [[] for _ in messages], - indices_iter, - strict=False, + for pos, (msg, streams, tgt_indices, s_idx) in enumerate( + zip( + messages, + complementary.get("message_streams") or [[] for _ in messages], + complementary.get("target_message_indices") or [[] for _ in messages], + indices_iter, + strict=False, + ) ) ] else: @@ -423,6 +453,7 @@ class PI052TextTokenizerStep(ProcessorStep): list(complementary.get("target_message_indices") or []), complementary, sample_idx=sample_idx, + state_row=_state_row_at(state_all, 0), ) ] @@ -446,6 +477,7 @@ class PI052TextTokenizerStep(ProcessorStep): target_indices: list[int], complementary: dict[str, Any], sample_idx: int | None = None, + state_row: Any = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor, str]: # Optional: drop non-target messages per the dropout config. # Keeps the supervised-target indices stable by re-mapping @@ -472,6 +504,17 @@ class PI052TextTokenizerStep(ProcessorStep): # stripping, so the spoken reply is actually tokenized and # supervised (PaliGemma's flat prompt has no structured calls). messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages] + # Low-level (action-conditioning) samples get the discretized state + # appended to their user message, mirroring pi05's + # "..., State: {256-bin};" so the action expert sees proprioception. + # Higher-level text streams (subtask/memory generation) stay state-free. + if state_row is not None and any(s == "low_level" for s in message_streams): + state_str = discretize_state_str(state_row) + for m in reversed(messages): + if m.get("role") == "user": + base = _content_to_text(m.get("content", "")) + m["content"] = f"{base}, State: {state_str};" + break # Append EOS to supervised target turns so the LM head learns to # stop (the span covers it → it becomes a supervised label). prompt, spans = _format_messages(