mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
feat(runtime): make the interactive runtime drive PI052 too
The runtime's text path was hard-wired to SmolVLA2: _build_text_batch read policy.config.vlm_model_name (which PI052Config doesn't have) and built a SmolVLM2 chat-template prompt. PI052/PaliGemma is not chat-pretrained and trains on a flat `User: ... \nAssistant: ...` prompt, so the runtime crashed or fed an out-of-distribution prefix. - _build_text_batch now dispatches on policy.config.type: smolvla2 -> chat template (renamed _build_text_batch_chat); pi052 -> flat role-prefixed text via PI052TextTokenizerStep's own _format_messages / _strip_blocks / _flatten_say_tool_calls, so the inference prefix matches PI052 training exactly. - Add a lerobot-pi052-runtime entry point (alias of the same main; the policy type is read from the checkpoint) so the command name isn't misleading. argparse prog now defaults to the invoked command name. PI052's select_message / predict_action_chunk already work with the runtime; this was the one SmolVLA2-only coupling. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -308,7 +308,12 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
# Interactive hierarchical-VLA runtime. The same entry point drives both
|
||||
# SmolVLA2 (SmolVLM2 backbone) and PI052 (PaliGemma backbone) — the
|
||||
# policy type is read from the checkpoint. ``lerobot-pi052-runtime`` is
|
||||
# an alias so the command name isn't misleading for PI052 users.
|
||||
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||
lerobot-pi052-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
|
||||
@@ -253,13 +253,85 @@ def _build_text_batch(
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Tokenize a list of chat messages into the batch shape
|
||||
``select_message`` expects.
|
||||
"""Tokenize chat messages into the batch ``select_message`` expects.
|
||||
|
||||
Lazy fallback: re-uses the policy's preprocessor by piggy-backing
|
||||
on the chat tokenizer step. Production use should construct the
|
||||
batch from a real observation; here we focus on the *language*
|
||||
path which is independent of camera observations.
|
||||
Dispatches on the policy backbone so one runtime drives both:
|
||||
|
||||
* ``smolvla2`` (SmolVLM2) — chat template via ``apply_chat_template``.
|
||||
* ``pi052`` (PaliGemma) — flat ``Role: content`` text, since
|
||||
PaliGemma is not chat-pretrained (mirrors ``PI052TextTokenizerStep``).
|
||||
"""
|
||||
if getattr(getattr(policy, "config", None), "type", "") == "pi052":
|
||||
return _build_text_batch_pi052(
|
||||
policy, prompt_messages, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return _build_text_batch_chat(
|
||||
policy, prompt_messages, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
|
||||
|
||||
def _build_text_batch_pi052(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""PI052 text batch — flat ``User: … \\nAssistant: …`` prompt.
|
||||
|
||||
PaliGemma ships no chat template, so PI052 trains on the plain
|
||||
role-prefixed concatenation built by ``PI052TextTokenizerStep``.
|
||||
Reuses that exact formatter so the inference prefix matches
|
||||
training. ``add_generation_prompt`` appends the bare ``Assistant: ``
|
||||
header the LM head continues from.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
|
||||
_flatten_say_tool_calls,
|
||||
_format_messages,
|
||||
_strip_blocks,
|
||||
)
|
||||
|
||||
tok_name = (
|
||||
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
|
||||
prompt, _spans = _format_messages(messages)
|
||||
if add_generation_prompt:
|
||||
prompt = prompt + "Assistant: "
|
||||
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
ids = encoded["input_ids"]
|
||||
attn = encoded.get("attention_mask")
|
||||
if attn is None and tokenizer.pad_token_id is not None:
|
||||
attn = ids != tokenizer.pad_token_id
|
||||
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
|
||||
attn = attn.bool()
|
||||
|
||||
device = getattr(getattr(policy, "config", None), "device", None)
|
||||
if device is not None:
|
||||
try:
|
||||
ids = ids.to(device)
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
def _build_text_batch_chat(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""SmolVLA2 (SmolVLM2) text batch — chat-template tokenization.
|
||||
|
||||
Reuses ``_strip_lerobot_blocks`` so the inference prompt shape
|
||||
matches the training-time chat tokenizer step exactly.
|
||||
"""
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
|
||||
@@ -67,8 +67,14 @@ logger = logging.getLogger("lerobot.smolvla2.runtime")
|
||||
|
||||
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="lerobot-smolvla2-runtime",
|
||||
description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
|
||||
# prog defaults to the invoked command name, so this reads
|
||||
# correctly whether run as lerobot-smolvla2-runtime or
|
||||
# lerobot-pi052-runtime.
|
||||
description=(
|
||||
"Interactive REPL runtime for a trained hierarchical VLA "
|
||||
"checkpoint (SmolVLA2 or PI052 — policy type is read from "
|
||||
"the checkpoint)."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--policy.path",
|
||||
|
||||
Reference in New Issue
Block a user