mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
feat(runtime): make the interactive runtime drive PI052 too
The runtime's text path was hard-wired to SmolVLA2: _build_text_batch read policy.config.vlm_model_name (which PI052Config doesn't have) and built a SmolVLM2 chat-template prompt. PI052/PaliGemma is not chat-pretrained and trains on a flat `User: ... \nAssistant: ...` prompt, so the runtime crashed or fed an out-of-distribution prefix. - _build_text_batch now dispatches on policy.config.type: smolvla2 -> chat template (renamed _build_text_batch_chat); pi052 -> flat role-prefixed text via PI052TextTokenizerStep's own _format_messages / _strip_blocks / _flatten_say_tool_calls, so the inference prefix matches PI052 training exactly. - Add a lerobot-pi052-runtime entry point (alias of the same main; the policy type is read from the checkpoint) so the command name isn't misleading. argparse prog now defaults to the invoked command name. PI052's select_message / predict_action_chunk already work with the runtime; this was the one SmolVLA2-only coupling. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -308,7 +308,12 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
|||||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||||
|
# Interactive hierarchical-VLA runtime. The same entry point drives both
|
||||||
|
# SmolVLA2 (SmolVLM2 backbone) and PI052 (PaliGemma backbone) — the
|
||||||
|
# policy type is read from the checkpoint. ``lerobot-pi052-runtime`` is
|
||||||
|
# an alias so the command name isn't misleading for PI052 users.
|
||||||
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||||
|
lerobot-pi052-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
|
|||||||
@@ -253,13 +253,85 @@ def _build_text_batch(
|
|||||||
*,
|
*,
|
||||||
add_generation_prompt: bool = True,
|
add_generation_prompt: bool = True,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Tokenize a list of chat messages into the batch shape
|
"""Tokenize chat messages into the batch ``select_message`` expects.
|
||||||
``select_message`` expects.
|
|
||||||
|
|
||||||
Lazy fallback: re-uses the policy's preprocessor by piggy-backing
|
Dispatches on the policy backbone so one runtime drives both:
|
||||||
on the chat tokenizer step. Production use should construct the
|
|
||||||
batch from a real observation; here we focus on the *language*
|
* ``smolvla2`` (SmolVLM2) — chat template via ``apply_chat_template``.
|
||||||
path which is independent of camera observations.
|
* ``pi052`` (PaliGemma) — flat ``Role: content`` text, since
|
||||||
|
PaliGemma is not chat-pretrained (mirrors ``PI052TextTokenizerStep``).
|
||||||
|
"""
|
||||||
|
if getattr(getattr(policy, "config", None), "type", "") == "pi052":
|
||||||
|
return _build_text_batch_pi052(
|
||||||
|
policy, prompt_messages, add_generation_prompt=add_generation_prompt
|
||||||
|
)
|
||||||
|
return _build_text_batch_chat(
|
||||||
|
policy, prompt_messages, add_generation_prompt=add_generation_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_text_batch_pi052(
|
||||||
|
policy: Any,
|
||||||
|
prompt_messages: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
add_generation_prompt: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""PI052 text batch — flat ``User: … \\nAssistant: …`` prompt.
|
||||||
|
|
||||||
|
PaliGemma ships no chat template, so PI052 trains on the plain
|
||||||
|
role-prefixed concatenation built by ``PI052TextTokenizerStep``.
|
||||||
|
Reuses that exact formatter so the inference prefix matches
|
||||||
|
training. ``add_generation_prompt`` appends the bare ``Assistant: ``
|
||||||
|
header the LM head continues from.
|
||||||
|
"""
|
||||||
|
import torch # noqa: PLC0415
|
||||||
|
from transformers import AutoTokenizer # noqa: PLC0415
|
||||||
|
|
||||||
|
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
|
||||||
|
_flatten_say_tool_calls,
|
||||||
|
_format_messages,
|
||||||
|
_strip_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
tok_name = (
|
||||||
|
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tok_name)
|
||||||
|
|
||||||
|
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
|
||||||
|
prompt, _spans = _format_messages(messages)
|
||||||
|
if add_generation_prompt:
|
||||||
|
prompt = prompt + "Assistant: "
|
||||||
|
|
||||||
|
encoded = tokenizer(prompt, return_tensors="pt")
|
||||||
|
ids = encoded["input_ids"]
|
||||||
|
attn = encoded.get("attention_mask")
|
||||||
|
if attn is None and tokenizer.pad_token_id is not None:
|
||||||
|
attn = ids != tokenizer.pad_token_id
|
||||||
|
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
|
||||||
|
attn = attn.bool()
|
||||||
|
|
||||||
|
device = getattr(getattr(policy, "config", None), "device", None)
|
||||||
|
if device is not None:
|
||||||
|
try:
|
||||||
|
ids = ids.to(device)
|
||||||
|
if attn is not None and hasattr(attn, "to"):
|
||||||
|
attn = attn.to(device)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
|
||||||
|
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_text_batch_chat(
|
||||||
|
policy: Any,
|
||||||
|
prompt_messages: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
add_generation_prompt: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""SmolVLA2 (SmolVLM2) text batch — chat-template tokenization.
|
||||||
|
|
||||||
|
Reuses ``_strip_lerobot_blocks`` so the inference prompt shape
|
||||||
|
matches the training-time chat tokenizer step exactly.
|
||||||
"""
|
"""
|
||||||
from transformers import AutoTokenizer # noqa: PLC0415
|
from transformers import AutoTokenizer # noqa: PLC0415
|
||||||
|
|
||||||
|
|||||||
@@ -67,8 +67,14 @@ logger = logging.getLogger("lerobot.smolvla2.runtime")
|
|||||||
|
|
||||||
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||||
p = argparse.ArgumentParser(
|
p = argparse.ArgumentParser(
|
||||||
prog="lerobot-smolvla2-runtime",
|
# prog defaults to the invoked command name, so this reads
|
||||||
description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
|
# correctly whether run as lerobot-smolvla2-runtime or
|
||||||
|
# lerobot-pi052-runtime.
|
||||||
|
description=(
|
||||||
|
"Interactive REPL runtime for a trained hierarchical VLA "
|
||||||
|
"checkpoint (SmolVLA2 or PI052 — policy type is read from "
|
||||||
|
"the checkpoint)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
"--policy.path",
|
"--policy.path",
|
||||||
|
|||||||
Reference in New Issue
Block a user