feat(runtime): make the interactive runtime drive PI052 too

The runtime's text path was hard-wired to SmolVLA2: _build_text_batch
read policy.config.vlm_model_name (which PI052Config doesn't have) and
built a SmolVLM2 chat-template prompt. PI052/PaliGemma is not
chat-pretrained and trains on a flat `User: ... \nAssistant: ...`
prompt, so the runtime crashed or fed an out-of-distribution prefix.

- _build_text_batch now dispatches on policy.config.type: smolvla2 ->
  chat template (renamed _build_text_batch_chat); pi052 -> flat
  role-prefixed text via PI052TextTokenizerStep's own _format_messages /
  _strip_blocks / _flatten_say_tool_calls, so the inference prefix
  matches PI052 training exactly.
- Add a lerobot-pi052-runtime entry point (alias of the same main; the
  policy type is read from the checkpoint) so the command name isn't
  misleading. argparse prog now defaults to the invoked command name.

PI052's select_message / predict_action_chunk already work with the
runtime; this was the one SmolVLA2-only coupling.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-19 14:28:55 +02:00
parent 7b64e5498d
commit 725ac95b0d
3 changed files with 91 additions and 8 deletions
+5
View File
@@ -308,7 +308,12 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
# Interactive hierarchical-VLA runtime. The same entry point drives both
# SmolVLA2 (SmolVLM2 backbone) and PI052 (PaliGemma backbone) — the
# policy type is read from the checkpoint. ``lerobot-pi052-runtime`` is
# an alias so the command name isn't misleading for PI052 users.
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
lerobot-pi052-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
@@ -253,13 +253,85 @@ def _build_text_batch(
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""Tokenize a list of chat messages into the batch shape
``select_message`` expects.
"""Tokenize chat messages into the batch ``select_message`` expects.
Lazy fallback: re-uses the policy's preprocessor by piggy-backing
on the chat tokenizer step. Production use should construct the
batch from a real observation; here we focus on the *language*
path which is independent of camera observations.
Dispatches on the policy backbone so one runtime drives both:
* ``smolvla2`` (SmolVLM2) chat template via ``apply_chat_template``.
* ``pi052`` (PaliGemma) flat ``Role: content`` text, since
PaliGemma is not chat-pretrained (mirrors ``PI052TextTokenizerStep``).
"""
if getattr(getattr(policy, "config", None), "type", "") == "pi052":
return _build_text_batch_pi052(
policy, prompt_messages, add_generation_prompt=add_generation_prompt
)
return _build_text_batch_chat(
policy, prompt_messages, add_generation_prompt=add_generation_prompt
)
def _build_text_batch_pi052(
policy: Any,
prompt_messages: list[dict[str, Any]],
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""PI052 text batch — flat ``User: … \\nAssistant: …`` prompt.
PaliGemma ships no chat template, so PI052 trains on the plain
role-prefixed concatenation built by ``PI052TextTokenizerStep``.
Reuses that exact formatter so the inference prefix matches
training. ``add_generation_prompt`` appends the bare ``Assistant: ``
header the LM head continues from.
"""
import torch # noqa: PLC0415
from transformers import AutoTokenizer # noqa: PLC0415
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
_flatten_say_tool_calls,
_format_messages,
_strip_blocks,
)
tok_name = (
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
)
tokenizer = AutoTokenizer.from_pretrained(tok_name)
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
prompt, _spans = _format_messages(messages)
if add_generation_prompt:
prompt = prompt + "Assistant: "
encoded = tokenizer(prompt, return_tensors="pt")
ids = encoded["input_ids"]
attn = encoded.get("attention_mask")
if attn is None and tokenizer.pad_token_id is not None:
attn = ids != tokenizer.pad_token_id
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
attn = attn.bool()
device = getattr(getattr(policy, "config", None), "device", None)
if device is not None:
try:
ids = ids.to(device)
if attn is not None and hasattr(attn, "to"):
attn = attn.to(device)
except Exception as exc: # noqa: BLE001
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
def _build_text_batch_chat(
policy: Any,
prompt_messages: list[dict[str, Any]],
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""SmolVLA2 (SmolVLM2) text batch — chat-template tokenization.
Reuses ``_strip_lerobot_blocks`` so the inference prompt shape
matches the training-time chat tokenizer step exactly.
"""
from transformers import AutoTokenizer # noqa: PLC0415
@@ -67,8 +67,14 @@ logger = logging.getLogger("lerobot.smolvla2.runtime")
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p = argparse.ArgumentParser(
prog="lerobot-smolvla2-runtime",
description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
# prog defaults to the invoked command name, so this reads
# correctly whether run as lerobot-smolvla2-runtime or
# lerobot-pi052-runtime.
description=(
"Interactive REPL runtime for a trained hierarchical VLA "
"checkpoint (SmolVLA2 or PI052 — policy type is read from "
"the checkpoint)."
),
)
p.add_argument(
"--policy.path",