mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
feat(pi052): condition low-level prompt on state + fix eval slowdown
- Inject discretized proprioceptive state (256 bins, pi05 format) into low-level action-conditioning prompts in both training (PI052TextTokenizerStep) and eval (_with_low_level_subtask_prompt), matching the recipe's documented "[images, subtask, state]" intent. Higher-level subtask/memory text streams stay state-free. - Cache the loc-token tokenizer (_get_loc_tokenizer) instead of reloading it from disk on every _build_text_batch/select_message call (it ran twice per env per replan and dominated eval runtime). - Add a KV cache to select_message decode (bit-identical output to the recompute path) to avoid O(n^2) generation. Net: pi052 eval ~2.9 s/it -> ~0.1 s/it (~25x). Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -248,6 +248,22 @@ class DispatchAction(InferenceStep):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_LOC_TOKENIZER_CACHE: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _get_loc_tokenizer(tok_name: str, auto_tokenizer_cls: Any, register_loc_fn: Any) -> Any:
|
||||
"""Return a loc-token-registered tokenizer, loading from disk only once.
|
||||
|
||||
``AutoTokenizer.from_pretrained`` + loc-token registration is expensive and
|
||||
the result is immutable, so cache per ``tok_name``.
|
||||
"""
|
||||
tokenizer = _LOC_TOKENIZER_CACHE.get(tok_name)
|
||||
if tokenizer is None:
|
||||
tokenizer = register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name))
|
||||
_LOC_TOKENIZER_CACHE[tok_name] = tokenizer
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _build_text_batch(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
@@ -277,7 +293,10 @@ def _build_text_batch(
|
||||
)
|
||||
# Register PaliGemma's <locDDDD> tokens so inference encoding /
|
||||
# decoding sees them as single vocab ids — must match training.
|
||||
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
|
||||
# The tokenizer is read-only after registration, so cache it: rebuilding it
|
||||
# from disk on every call dominated eval runtime (this runs twice per env
|
||||
# per replan — subtask gen + action prompt).
|
||||
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
|
||||
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
|
||||
prompt, _spans = _format_messages(messages)
|
||||
|
||||
@@ -55,6 +55,7 @@ from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
from lerobot.utils.import_utils import require_package
|
||||
@@ -1592,6 +1593,7 @@ class PI052Policy(PreTrainedPolicy):
|
||||
top_p: float = 1.0,
|
||||
tokenizer: Any = None,
|
||||
suppress_loc_tokens: bool = False,
|
||||
use_kv_cache: bool = True,
|
||||
) -> str:
|
||||
"""Generate text continuation from a multimodal prefix.
|
||||
|
||||
@@ -1613,11 +1615,11 @@ class PI052Policy(PreTrainedPolicy):
|
||||
if tokenizer is None:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
from .inference.steps import _get_loc_tokenizer # noqa: PLC0415
|
||||
from .text_processor_pi052 import register_paligemma_loc_tokens # noqa: PLC0415
|
||||
|
||||
tokenizer = register_paligemma_loc_tokens(
|
||||
AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
)
|
||||
tok_name = getattr(self.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
|
||||
if eos_token_id is None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
@@ -1647,7 +1649,16 @@ class PI052Policy(PreTrainedPolicy):
|
||||
current_pad = prefix_pad_masks
|
||||
current_att = prefix_att_masks
|
||||
generated: list[int] = []
|
||||
new_emb = None
|
||||
|
||||
# KV-cache decode: encode the (image-heavy) prefix once, then feed only
|
||||
# the newly sampled token each step, attending to the cached keys. This
|
||||
# turns an O(n_tokens * prefix_len) recompute into O(prefix_len + n_tokens)
|
||||
# and is the dominant cost here (the prefix carries ~3*256 image tokens).
|
||||
# With ``use_kv_cache=False`` the loop reduces exactly to the original
|
||||
# recompute path (cache stays ``None`` so every step re-runs the full
|
||||
# prefix), which we keep as a fallback / parity reference.
|
||||
cache = None
|
||||
|
||||
backbone = self.model.paligemma_with_expert
|
||||
lm_head = backbone.paligemma.lm_head
|
||||
@@ -1666,16 +1677,32 @@ class PI052Policy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
att_2d = make_att_2d_masks(current_pad, current_att)
|
||||
position_ids = torch.cumsum(current_pad, dim=1) - 1
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
|
||||
(vlm_out, _), _ = backbone.forward(
|
||||
if cache is None:
|
||||
# First step (and every step when caching is disabled): run the
|
||||
# full bidirectional-prefix forward. ``current_*`` already grow
|
||||
# in the no-cache fallback below.
|
||||
step_embs = current_embs
|
||||
att_2d = make_att_2d_masks(current_pad, current_att)
|
||||
position_ids = torch.cumsum(current_pad, dim=1) - 1
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
|
||||
else:
|
||||
# Incremental step: only the last token. It attends to every
|
||||
# valid cached key (``current_pad`` already includes this token),
|
||||
# so pad positions in the prefix stay masked just like the
|
||||
# recompute path.
|
||||
step_embs = new_emb
|
||||
att_2d = current_pad[:, None, :]
|
||||
att_2d_4d = self.model._prepare_attention_masks_4d(att_2d, dtype=backbone_dtype)
|
||||
position_ids = (torch.cumsum(current_pad, dim=1) - 1)[:, -1:]
|
||||
(vlm_out, _), new_cache = backbone.forward(
|
||||
attention_mask=att_2d_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[current_embs, None],
|
||||
use_cache=False,
|
||||
past_key_values=cache,
|
||||
inputs_embeds=[step_embs, None],
|
||||
use_cache=use_kv_cache,
|
||||
)
|
||||
if use_kv_cache:
|
||||
cache = new_cache
|
||||
if vlm_out is None:
|
||||
break
|
||||
last = vlm_out[:, -1:].to(lm_head.weight.dtype)
|
||||
@@ -1702,9 +1729,13 @@ class PI052Policy(PreTrainedPolicy):
|
||||
|
||||
# embed_language_tokens already applies the Gemma sqrt(hidden) scale (tf>=5.4.0).
|
||||
new_emb = backbone.embed_language_tokens(next_ids.unsqueeze(0))
|
||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||
# ``current_pad`` tracks valid keys for both paths (cache mask +
|
||||
# position ids). Only the recompute path needs the full embedding /
|
||||
# block-attention history re-fed each step.
|
||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
if not use_kv_cache:
|
||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
|
||||
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||
if not decoded and generated:
|
||||
@@ -1744,10 +1775,15 @@ class PI052Policy(PreTrainedPolicy):
|
||||
|
||||
def _with_low_level_subtask_prompt(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
from .inference.steps import _build_text_batch # noqa: PLC0415
|
||||
from .text_processor_pi052 import discretize_state_str # noqa: PLC0415
|
||||
|
||||
n = self._batch_size_from_observation(batch)
|
||||
self._ensure_subtask_state(n)
|
||||
tasks = self._tasks_from_batch(batch, n)
|
||||
# Normalized state for the low-level action prompt (mirrors training:
|
||||
# "User: {subtask}, State: {256-bin};"). batch state is already
|
||||
# normalized by the eval preprocessor's NormalizerProcessorStep.
|
||||
state_all = batch.get(OBS_STATE)
|
||||
|
||||
# Generate one subtask per parallel env, each conditioned on that env's
|
||||
# own task + observation, then stack the per-env prompts into a single
|
||||
@@ -1758,9 +1794,12 @@ class PI052Policy(PreTrainedPolicy):
|
||||
for i in range(n):
|
||||
obs_i = self._slice_observation(batch, i)
|
||||
subtask = self._generate_low_level_subtask(obs_i, tasks[i], i)
|
||||
content = subtask
|
||||
if torch.is_tensor(state_all):
|
||||
content = f"{subtask}, State: {discretize_state_str(state_all[i])};"
|
||||
text_batch = _build_text_batch(
|
||||
self,
|
||||
[{"role": "user", "content": subtask}],
|
||||
[{"role": "user", "content": content}],
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
rows.append((text_batch["lang_tokens"], text_batch["lang_masks"]))
|
||||
|
||||
@@ -40,17 +40,40 @@ import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discretize_state_str(state_row: Any) -> str:
|
||||
"""Discretize a single normalized state vector into 256 bins, space-joined.
|
||||
|
||||
Mirrors pi05's ``Pi05PrepareStateTokenizerProcessorStep`` (same bins /
|
||||
convention) so pi052's low-level action prompt carries proprioception in
|
||||
the exact format pi05 was trained on. Expects state already normalized by
|
||||
the upstream ``NormalizerProcessorStep``.
|
||||
"""
|
||||
arr = state_row.detach().cpu().numpy() if hasattr(state_row, "detach") else np.asarray(state_row)
|
||||
disc = np.digitize(arr, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
return " ".join(str(int(x)) for x in disc.reshape(-1).tolist())
|
||||
|
||||
|
||||
def _state_row_at(state_all: Any, pos: int) -> Any:
|
||||
"""Select the per-sample state row from a (possibly batched) state tensor."""
|
||||
if state_all is None:
|
||||
return None
|
||||
if hasattr(state_all, "ndim") and state_all.ndim >= 2:
|
||||
return state_all[pos]
|
||||
return state_all
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
|
||||
if isinstance(content, str):
|
||||
@@ -391,6 +414,10 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
return transition
|
||||
|
||||
tokenizer = self._ensure_tokenizer()
|
||||
# Normalized proprioceptive state (set by NormalizerProcessorStep, which
|
||||
# runs before this step). Injected into low-level action prompts so the
|
||||
# action expert sees proprioception, matching pi05's discretized State:.
|
||||
state_all = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE)
|
||||
# VQA coords are 0–1000 normalized (Qwen2.5-VL convention) — the
|
||||
# <loc> conversion is camera-resolution-independent and needs no
|
||||
# observation lookup here.
|
||||
@@ -404,13 +431,16 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
list(tgt_indices),
|
||||
complementary,
|
||||
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||
state_row=_state_row_at(state_all, pos),
|
||||
)
|
||||
for msg, streams, tgt_indices, s_idx in zip(
|
||||
messages,
|
||||
complementary.get("message_streams") or [[] for _ in messages],
|
||||
complementary.get("target_message_indices") or [[] for _ in messages],
|
||||
indices_iter,
|
||||
strict=False,
|
||||
for pos, (msg, streams, tgt_indices, s_idx) in enumerate(
|
||||
zip(
|
||||
messages,
|
||||
complementary.get("message_streams") or [[] for _ in messages],
|
||||
complementary.get("target_message_indices") or [[] for _ in messages],
|
||||
indices_iter,
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
]
|
||||
else:
|
||||
@@ -423,6 +453,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
list(complementary.get("target_message_indices") or []),
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
state_row=_state_row_at(state_all, 0),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -446,6 +477,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
state_row: Any = None,
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
|
||||
# Optional: drop non-target messages per the dropout config.
|
||||
# Keeps the supervised-target indices stable by re-mapping
|
||||
@@ -472,6 +504,17 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
# stripping, so the spoken reply is actually tokenized and
|
||||
# supervised (PaliGemma's flat prompt has no structured calls).
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
||||
# Low-level (action-conditioning) samples get the discretized state
|
||||
# appended to their user message, mirroring pi05's
|
||||
# "..., State: {256-bin};" so the action expert sees proprioception.
|
||||
# Higher-level text streams (subtask/memory generation) stay state-free.
|
||||
if state_row is not None and any(s == "low_level" for s in message_streams):
|
||||
state_str = discretize_state_str(state_row)
|
||||
for m in reversed(messages):
|
||||
if m.get("role") == "user":
|
||||
base = _content_to_text(m.get("content", ""))
|
||||
m["content"] = f"{base}, State: {state_str};"
|
||||
break
|
||||
# Append EOS to supervised target turns so the LM head learns to
|
||||
# stop (the span covers it → it becomes a supervised label).
|
||||
prompt, spans = _format_messages(
|
||||
|
||||
Reference in New Issue
Block a user