From 725ac95b0d1c894c645243d541d59e934e26afd6 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 19 May 2026 14:28:55 +0200 Subject: [PATCH] feat(runtime): make the interactive runtime drive PI052 too The runtime's text path was hard-wired to SmolVLA2: _build_text_batch read policy.config.vlm_model_name (which PI052Config doesn't have) and built a SmolVLM2 chat-template prompt. PI052/PaliGemma is not chat-pretrained and trains on a flat `User: ... \nAssistant: ...` prompt, so the runtime crashed or fed an out-of-distribution prefix. - _build_text_batch now dispatches on policy.config.type: smolvla2 -> chat template (renamed _build_text_batch_chat); pi052 -> flat role-prefixed text via PI052TextTokenizerStep's own _format_messages / _strip_blocks / _flatten_say_tool_calls, so the inference prefix matches PI052 training exactly. - Add a lerobot-pi052-runtime entry point (alias of the same main; the policy type is read from the checkpoint) so the command name isn't misleading. argparse prog now defaults to the invoked command name. PI052's select_message / predict_action_chunk already work with the runtime; this was the one SmolVLA2-only coupling. Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 5 ++ .../policies/smolvla2/inference/steps.py | 84 +++++++++++++++++-- .../scripts/lerobot_smolvla2_runtime.py | 10 ++- 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 360a00e1b..0a7ced215 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -308,7 +308,12 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" lerobot-annotate="lerobot.scripts.lerobot_annotate:main" +# Interactive hierarchical-VLA runtime. The same entry point drives both +# SmolVLA2 (SmolVLM2 backbone) and PI052 (PaliGemma backbone) — the +# policy type is read from the checkpoint. ``lerobot-pi052-runtime`` is +# an alias so the command name isn't misleading for PI052 users. lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main" +lerobot-pi052-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.package-data] diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 93417b824..403e10003 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -253,13 +253,85 @@ def _build_text_batch( *, add_generation_prompt: bool = True, ) -> dict[str, Any]: - """Tokenize a list of chat messages into the batch shape - ``select_message`` expects. + """Tokenize chat messages into the batch ``select_message`` expects. - Lazy fallback: re-uses the policy's preprocessor by piggy-backing - on the chat tokenizer step. Production use should construct the - batch from a real observation; here we focus on the *language* - path which is independent of camera observations. + Dispatches on the policy backbone so one runtime drives both: + + * ``smolvla2`` (SmolVLM2) — chat template via ``apply_chat_template``. + * ``pi052`` (PaliGemma) — flat ``Role: content`` text, since + PaliGemma is not chat-pretrained (mirrors ``PI052TextTokenizerStep``). + """ + if getattr(getattr(policy, "config", None), "type", "") == "pi052": + return _build_text_batch_pi052( + policy, prompt_messages, add_generation_prompt=add_generation_prompt + ) + return _build_text_batch_chat( + policy, prompt_messages, add_generation_prompt=add_generation_prompt + ) + + +def _build_text_batch_pi052( + policy: Any, + prompt_messages: list[dict[str, Any]], + *, + add_generation_prompt: bool = True, +) -> dict[str, Any]: + """PI052 text batch — flat ``User: … \\nAssistant: …`` prompt. + + PaliGemma ships no chat template, so PI052 trains on the plain + role-prefixed concatenation built by ``PI052TextTokenizerStep``. + Reuses that exact formatter so the inference prefix matches + training. ``add_generation_prompt`` appends the bare ``Assistant: `` + header the LM head continues from. + """ + import torch # noqa: PLC0415 + from transformers import AutoTokenizer # noqa: PLC0415 + + from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415 + _flatten_say_tool_calls, + _format_messages, + _strip_blocks, + ) + + tok_name = ( + getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224" + ) + tokenizer = AutoTokenizer.from_pretrained(tok_name) + + messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages] + prompt, _spans = _format_messages(messages) + if add_generation_prompt: + prompt = prompt + "Assistant: " + + encoded = tokenizer(prompt, return_tensors="pt") + ids = encoded["input_ids"] + attn = encoded.get("attention_mask") + if attn is None and tokenizer.pad_token_id is not None: + attn = ids != tokenizer.pad_token_id + if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool: + attn = attn.bool() + + device = getattr(getattr(policy, "config", None), "device", None) + if device is not None: + try: + ids = ids.to(device) + if attn is not None and hasattr(attn, "to"): + attn = attn.to(device) + except Exception as exc: # noqa: BLE001 + logger.debug("could not move pi052 lang tokens to %s: %s", device, exc) + return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer} + + +def _build_text_batch_chat( + policy: Any, + prompt_messages: list[dict[str, Any]], + *, + add_generation_prompt: bool = True, +) -> dict[str, Any]: + """SmolVLA2 (SmolVLM2) text batch — chat-template tokenization. + + Reuses ``_strip_lerobot_blocks`` so the inference prompt shape + matches the training-time chat tokenizer step exactly. """ from transformers import AutoTokenizer # noqa: PLC0415 diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index 645985296..cc5097a06 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -67,8 +67,14 @@ logger = logging.getLogger("lerobot.smolvla2.runtime") def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: p = argparse.ArgumentParser( - prog="lerobot-smolvla2-runtime", - description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.", + # prog defaults to the invoked command name, so this reads + # correctly whether run as lerobot-smolvla2-runtime or + # lerobot-pi052-runtime. + description=( + "Interactive REPL runtime for a trained hierarchical VLA " + "checkpoint (SmolVLA2 or PI052 — policy type is read from " + "the checkpoint)." + ), ) p.add_argument( "--policy.path",