examples(port_datasets): SLURM+datatrove RoboCasa composite_seen build

Parallel variant of build_robocasa_composite_seen.py modeled after the
existing slurm_port_shards.py / slurm_aggregate_shards.py pattern.

Two-phase datatrove pipeline:
  * Phase 1 DOWNLOAD: tasks=16 (one per RoboCasa composite_seen task),
    each worker downloads its assigned tar via RoboCasa's own
    download_datasets helper. Network-bound, idempotent.
  * Phase 2 AGGREGATE: tasks=1, single worker calls aggregate_datasets
    over the 16 extracted directories. Submitted with depends=phase1 so
    SLURM only releases it once all 16 downloads succeed.

Reuses the COMPOSITE_SEEN_TASKS list and per-task download/resolve
helpers from the single-machine script via aliased imports — single
source of truth for 'what does it mean to download a composite_seen
task'.

Local (--slurm 0) mode runs the two phases sequentially in-process for
debugging on a workstation.

Usage on SLURM:
    uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \
        --output-dir=/scratch/${USER}/robocasa_composite_seen \
        --hub-repo-id=${HF_USER}/robocasa_composite_seen \
        --logs-dir=/scratch/${USER}/logs/robocasa \
        --partition=cpu --push-to-hub

Prereq: uv sync --extra annotations  (pulls datatrove)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-25 14:10:05 +02:00
parent 9c3d5ab7ce
commit a088c10c80
31 changed files with 666 additions and 2432 deletions
@@ -0,0 +1,541 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build the merged RoboCasa composite_seen dataset on SLURM via datatrove.
Two-phase pipeline modeled after ``slurm_port_shards.py`` +
``slurm_aggregate_shards.py``:
* Phase 1 — DOWNLOAD (parallel, 16 tasks, 1 per worker):
Each datatrove worker downloads one of the 16 RoboCasa composite_seen task
archives (``v1.0/target/composite/<Task>/<date>/lerobot.tar``) via
RoboCasa's own ``download_datasets`` helper. Idempotent — already-extracted
tasks are skipped. Network-bound, CPU-light.
* Phase 2 — AGGREGATE (single worker, depends on phase 1):
One worker calls ``aggregate_datasets`` over the 16 extracted directories,
producing a single combined LeRobotDataset. Validates fps / robot_type /
features, unifies task indices, concatenates videos + parquet, recomputes
stats. CPU + disk heavy.
When run under SLURM, phase 2 is submitted with ``depends=phase_1_executor``
so the scheduler only releases it after every download task succeeds.
Local (``--slurm 0``) execution runs the two phases sequentially in the
current process — useful for debugging on a workstation.
Usage on SLURM::
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\
--output-dir=/scratch/${USER}/robocasa_composite_seen \\
--hub-repo-id=${HF_USER}/robocasa_composite_seen \\
--logs-dir=/scratch/${USER}/logs/robocasa \\
--partition=cpu \\
--download-cpus=4 --download-mem=8G \\
--aggregate-cpus=16 --aggregate-mem-per-cpu=4G \\
--push-to-hub
Local debug (sequential, single process)::
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\
--output-dir=/tmp/robocasa_composite_seen \\
--slurm=0 \\
--tasks=PrepareCoffee,KettleBoiling
Prereqs: ``robocasa`` + ``robosuite`` installed (see
``docs/source/benchmarks/robocasa.mdx``); ``datatrove`` installed via the
``annotations`` extra (``uv sync --extra annotations``).
"""
from __future__ import annotations
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
# Reuse the per-task helpers + canonical task list from the single-machine
# script so both runners share one source of truth for "what does it mean to
# download a composite_seen task". The helpers are spelled with leading
# underscores there (module-private), but the slurm runner is a legitimate
# in-tree consumer, so we alias them to clean names here.
from lerobot.scripts.build_robocasa_composite_seen import (
COMPOSITE_SEEN_TASKS,
_download_task as download_task,
_require_robocasa as require_robocasa,
_resolve_task_root as resolve_task_root,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Pipeline steps
# ---------------------------------------------------------------------------
class DownloadRoboCasaTask(PipelineStep):
    """Phase 1 step: fetch every RoboCasa task that belongs to the calling rank.

    The task list is partitioned with the stride slice
    ``task_names[rank::world_size]``. When ``world_size`` equals
    ``len(task_names)`` every rank ends up with exactly one task; a smaller
    ``world_size`` hands each rank several tasks, which are then downloaded
    one after another.

    Args:
        task_names: RoboCasa task names to download (copied on construction).
        overwrite: Forwarded to the download helper as ``overwrite=``.
    """

    def __init__(self, task_names: list[str], *, overwrite: bool = False):
        super().__init__()
        self.overwrite = overwrite
        # Defensive copy so later mutation of the caller's list cannot change
        # which tasks this step owns.
        self.task_names = list(task_names)

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Download this rank's share of ``task_names``.

        ``data`` is accepted for pipeline-step signature compatibility and is
        not used. Nothing is returned.
        """
        from lerobot.utils.utils import init_logging  # noqa: PLC0415

        init_logging()
        require_robocasa()

        # Stride slice: rank r takes items r, r + world_size, r + 2 * world_size, ...
        owned = self.task_names[rank::world_size]
        if owned:
            for name in owned:
                logger.info("rank %d/%d: downloading %s", rank, world_size, name)
                extracted_root = download_task(name, overwrite=self.overwrite)
                logger.info("rank %d/%d: %s extracted at %s", rank, world_size, name, extracted_root)
        else:
            logger.info("rank %d/%d: no tasks assigned", rank, world_size)
class AggregateRoboCasaShards(PipelineStep):
    """Phase 2 step: merge the extracted per-task directories into one dataset.

    Only rank 0 performs the merge; any other rank logs a message and returns
    immediately.

    Args:
        task_names: RoboCasa task names whose extracted directories are merged
            (copied on construction).
        output_repo_id: ``repo_id`` given to the merged dataset.
        output_dir: Local directory passed to the merge as ``aggr_root``.
        push_to_hub: When true, upload the merged dataset after aggregation.
        private: Passed through to ``push_to_hub`` as ``private=``.
    """

    def __init__(
        self,
        task_names: list[str],
        *,
        output_repo_id: str,
        output_dir: Path,
        push_to_hub: bool,
        private: bool,
    ):
        super().__init__()
        self.task_names = list(task_names)
        self.output_repo_id = output_repo_id
        # Normalise to ``Path`` so a plain string is accepted as well.
        self.output_dir = Path(output_dir)
        self.push_to_hub = push_to_hub
        self.private = private

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Aggregate all task directories and optionally push the result.

        ``data`` and ``world_size`` are accepted for pipeline-step signature
        compatibility and are not used.

        Raises:
            RuntimeError: If the resolved extraction root of any task does not
                exist on disk.
        """
        from lerobot.utils.utils import init_logging  # noqa: PLC0415

        init_logging()

        if rank != 0:
            logger.info("rank %d: aggregation runs on rank 0 only — skipping", rank)
            return

        require_robocasa()

        from lerobot.datasets import aggregate_datasets  # noqa: PLC0415
        from lerobot.datasets.lerobot_dataset import LeRobotDataset  # noqa: PLC0415

        # Resolve every task first, then split into present / absent so that a
        # single error can list *all* tasks whose extraction is missing.
        resolved = [(task, resolve_task_root(task)) for task in self.task_names]
        absent = [f"{task} -> {root}" for task, root in resolved if not root.exists()]
        if absent:
            message = (
                "Phase 1 did not produce extracted directories for: "
                + ", ".join(absent)
                + " — re-run the download phase before aggregating."
            )
            raise RuntimeError(message)
        roots: list[Path] = [root for _, root in resolved]

        # Synthetic per-task repo ids, one per root, in the same order.
        # NOTE(review): presumably only used as labels by ``aggregate_datasets``
        # when ``roots=`` is supplied — confirm against its implementation.
        repo_ids = [f"robocasa/{task}_target_human" for task in self.task_names]

        logger.info(
            "Aggregating %d source datasets into %s at %s",
            len(roots),
            self.output_repo_id,
            self.output_dir,
        )
        aggregate_datasets(
            repo_ids=repo_ids,
            aggr_repo_id=self.output_repo_id,
            roots=roots,
            aggr_root=self.output_dir,
        )
        logger.info("Aggregation complete.")

        if not self.push_to_hub:
            return

        # Re-open the freshly written dataset from disk to upload it.
        merged = LeRobotDataset(repo_id=self.output_repo_id, root=self.output_dir)
        logger.info(
            "Pushing %s to the Hub (private=%s, %d episodes, %d frames)",
            self.output_repo_id,
            self.private,
            merged.num_episodes,
            merged.num_frames,
        )
        merged.push_to_hub(
            private=self.private,
            upload_large_folder=True,
            tags=["lerobot", "robocasa", "composite_seen", "manipulation"],
        )
        logger.info(
            "Push complete: https://huggingface.co/datasets/%s",
            self.output_repo_id,
        )
# ---------------------------------------------------------------------------
# Executors
# ---------------------------------------------------------------------------
def make_download_executor(
    *,
    task_names: list[str],
    overwrite: bool,
    job_name: str,
    logs_dir: Path,
    workers: int,
    partition: str | None,
    cpus_per_task: int,
    mem: str,
    time: str,
    slurm: bool,
):
    """Build the phase-1 (download) executor.

    The executor always gets ``len(task_names)`` datatrove tasks, i.e. one
    shard per RoboCasa task, and logs under ``logs_dir / job_name``.

    Args:
        task_names: RoboCasa tasks to download.
        overwrite: Forwarded to :class:`DownloadRoboCasaTask`.
        job_name: SLURM job name; also the log sub-directory name.
        logs_dir: Parent directory for datatrove logs.
        workers: Worker count. Passed unchanged to the SLURM executor; clamped
            to ``len(task_names)`` for the local executor.
        partition: SLURM partition (SLURM only).
        cpus_per_task: CPUs per SLURM task (SLURM only).
        mem: Value for the ``mem`` sbatch argument (SLURM only).
        time: SLURM wall-clock limit (SLURM only).
        slurm: ``True`` for a ``SlurmPipelineExecutor``, ``False`` for a
            ``LocalPipelineExecutor``.

    Returns:
        The configured (not yet started) executor.
    """
    steps = [DownloadRoboCasaTask(task_names, overwrite=overwrite)]
    log_path = str(logs_dir / job_name)
    shard_count = len(task_names)  # one shard per RoboCasa task

    if not slurm:
        return LocalPipelineExecutor(
            pipeline=steps,
            logging_dir=log_path,
            tasks=shard_count,
            workers=min(workers, shard_count),
        )

    return SlurmPipelineExecutor(
        pipeline=steps,
        logging_dir=log_path,
        job_name=job_name,
        tasks=shard_count,
        workers=workers,
        time=time,
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem": mem},
    )
def make_aggregate_executor(
    *,
    task_names: list[str],
    output_repo_id: str,
    output_dir: Path,
    push_to_hub: bool,
    private: bool,
    job_name: str,
    logs_dir: Path,
    partition: str | None,
    cpus_per_task: int,
    mem_per_cpu: str,
    time: str,
    slurm: bool,
    depends: SlurmPipelineExecutor | None,
):
    """Build the phase-2 (aggregate) executor: one task, one worker.

    Args:
        task_names: RoboCasa tasks whose extracted directories are merged.
        output_repo_id: ``repo_id`` of the merged dataset.
        output_dir: Local directory for the merged dataset.
        push_to_hub: Whether the step uploads the merged dataset.
        private: Whether the Hub repo is created as private.
        job_name: SLURM job name; also the log sub-directory name.
        logs_dir: Parent directory for datatrove logs.
        partition: SLURM partition (SLURM only).
        cpus_per_task: CPUs for the aggregation task (SLURM only).
        mem_per_cpu: Value for the ``mem-per-cpu`` sbatch argument (SLURM only).
        time: SLURM wall-clock limit (SLURM only).
        slurm: ``True`` for a ``SlurmPipelineExecutor``, ``False`` for a
            ``LocalPipelineExecutor``.
        depends: Executor this job depends on (SLURM only; ignored locally).

    Returns:
        The configured (not yet started) executor.
    """
    merge_step = AggregateRoboCasaShards(
        task_names,
        output_repo_id=output_repo_id,
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        private=private,
    )
    steps = [merge_step]
    log_path = str(logs_dir / job_name)

    if not slurm:
        return LocalPipelineExecutor(
            pipeline=steps,
            logging_dir=log_path,
            tasks=1,
            workers=1,
        )

    return SlurmPipelineExecutor(
        pipeline=steps,
        logging_dir=log_path,
        job_name=job_name,
        tasks=1,
        workers=1,
        time=time,
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
        depends=depends,  # SLURM job dependency on phase 1
    )
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Build the merged RoboCasa composite_seen LeRobotDataset on SLURM "
"via datatrove (download in parallel, aggregate sequentially)."
),
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
# I/O.
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="Local directory for the merged dataset (will be created).",
)
parser.add_argument(
"--hub-repo-id",
type=str,
default=None,
help=(
"Hub repo_id for the merged dataset (e.g. "
"``yourname/robocasa_composite_seen``). Required for "
"``--push-to-hub``; also becomes the merged dataset's "
"canonical ``repo_id``."
),
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the merged dataset to the Hub after aggregation.",
)
parser.add_argument(
"--private",
action="store_true",
help="When pushing, create the Hub repo as private.",
)
parser.add_argument(
"--logs-dir",
type=Path,
default=Path("./logs/robocasa"),
help="Path to datatrove logs directory (used for stdout/stderr and "
"phase coordination).",
)
parser.add_argument(
"--tasks",
type=str,
default=None,
help="Comma-separated task names overriding the default 16 "
"composite_seen list (debug / smoke-test).",
)
parser.add_argument(
"--overwrite-download",
action="store_true",
help="Force re-download even if the local extraction looks complete.",
)
# SLURM controls.
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over SLURM (``1``) or locally / sequentially (``0``).",
)
parser.add_argument(
"--partition",
type=str,
default=None,
help="SLURM partition. A CPU partition is sufficient — no GPU needed.",
)
# Phase-1 (download) sizing.
parser.add_argument(
"--download-workers",
type=int,
default=16,
help="Number of parallel SLURM workers for the download phase. "
"Default matches the number of composite_seen tasks (16).",
)
parser.add_argument(
"--download-cpus",
type=int,
default=4,
help="CPUs per download worker (the work is network- and "
"tar-extraction-bound).",
)
parser.add_argument(
"--download-mem",
type=str,
default="8G",
help="Total memory per download worker.",
)
parser.add_argument(
"--download-time",
type=str,
default="06:00:00",
help="SLURM wall-clock limit per download worker (HH:MM:SS).",
)
# Phase-2 (aggregate) sizing.
parser.add_argument(
"--aggregate-cpus",
type=int,
default=16,
help="CPUs for the aggregation worker (ffmpeg + parquet I/O parallelize).",
)
parser.add_argument(
"--aggregate-mem-per-cpu",
type=str,
default="2G",
help="SLURM mem-per-cpu for the aggregation worker.",
)
parser.add_argument(
"--aggregate-time",
type=str,
default="12:00:00",
help="SLURM wall-clock limit for aggregation (HH:MM:SS). Tens-of-GB "
"merges can take several hours.",
)
# Job naming.
parser.add_argument(
"--download-job-name",
type=str,
default="robocasa_dl",
help="SLURM job name for phase 1.",
)
parser.add_argument(
"--aggregate-job-name",
type=str,
default="robocasa_agg",
help="SLURM job name for phase 2.",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
)
return parser.parse_args()
def main() -> int:
    """Entry point: build both phase executors and run or submit them.

    In SLURM mode only the aggregate executor is started; it is created with
    ``depends=download_executor`` so the download phase is submitted along
    with it. In local mode the download executor runs to completion first and
    the aggregate executor runs afterwards in the same process.

    Returns:
        ``0`` on success (used as the process exit code).

    Raises:
        SystemExit: If no task is selected, or if ``--push-to-hub`` is given
            without ``--hub-repo-id``.
    """
    args = parse_args()
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="[%(levelname)s] %(name)s: %(message)s",
    )

    if args.tasks:
        # ``dict.fromkeys`` drops repeated names while keeping first-seen
        # order. A duplicated name would otherwise hand the same archive to
        # two download ranks and feed the same root to the merge twice.
        requested = (name.strip() for name in args.tasks.split(","))
        tasks = list(dict.fromkeys(name for name in requested if name))
    else:
        tasks = list(COMPOSITE_SEEN_TASKS)
    if not tasks:
        raise SystemExit("No tasks selected.")

    if args.push_to_hub and not args.hub_repo_id:
        raise SystemExit("--push-to-hub requires --hub-repo-id.")

    # Without a Hub repo id the merged dataset still needs a ``repo_id``.
    output_repo_id = args.hub_repo_id or "local/robocasa_composite_seen"
    slurm = args.slurm == 1

    logger.info(
        "Phase 1 (download) — %d tasks across %d workers (slurm=%s)",
        len(tasks),
        min(args.download_workers, len(tasks)),
        slurm,
    )
    download_executor = make_download_executor(
        task_names=tasks,
        overwrite=args.overwrite_download,
        job_name=args.download_job_name,
        logs_dir=args.logs_dir,
        workers=args.download_workers,
        partition=args.partition,
        cpus_per_task=args.download_cpus,
        mem=args.download_mem,
        time=args.download_time,
        slurm=slurm,
    )

    logger.info(
        "Phase 2 (aggregate) — single worker, output: %s (push_to_hub=%s)",
        output_repo_id,
        args.push_to_hub,
    )
    aggregate_executor = make_aggregate_executor(
        task_names=tasks,
        output_repo_id=output_repo_id,
        output_dir=args.output_dir,
        push_to_hub=args.push_to_hub,
        private=args.private,
        job_name=args.aggregate_job_name,
        logs_dir=args.logs_dir,
        partition=args.partition,
        cpus_per_task=args.aggregate_cpus,
        mem_per_cpu=args.aggregate_mem_per_cpu,
        time=args.aggregate_time,
        slurm=slurm,
        depends=download_executor if slurm else None,
    )

    if slurm:
        # Submitting the aggregate executor with ``depends=download_executor``
        # also submits the download executor — SlurmPipelineExecutor walks
        # the dependency chain and submits each job once with the right
        # ``--dependency=afterok:<jobid>`` arg.
        aggregate_executor.run()
        # ``run()`` here only queues the jobs; nothing has been downloaded or
        # merged yet, so do not claim the dataset exists.
        logger.info(
            "SLURM jobs submitted. The merged dataset will be written to %s "
            "once the aggregate job finishes; job logs are under %s.",
            args.output_dir,
            args.logs_dir,
        )
    else:
        # Local sequential: run download to completion, then aggregate.
        download_executor.run()
        aggregate_executor.run()
        logger.info("Done. Merged dataset at %s.", args.output_dir)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
+2 -6
View File
@@ -308,12 +308,8 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
# Interactive hierarchical-VLA runtime. The same entry point drives both
# SmolVLA2 (SmolVLM2 backbone) and PI052 (PaliGemma backbone) — the
# policy type is read from the checkpoint. ``lerobot-pi052-runtime`` is
# an alias so the command name isn't misleading for PI052 users.
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
lerobot-pi052-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
@@ -129,8 +129,8 @@ class Executor:
written = self.writer.write_all(records, staging_dir, root)
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
# Persist the tool catalog to meta/info.json so chat-template
# consumers (PR 3 SmolVLA2 / Pi0.5 / dataset visualizer) can read
# Persist the tool catalog to meta/info.json so downstream
# consumers (PI052 / Pi0.5 / dataset visualizer) can read
# it via ``LeRobotDatasetMetadata.tools`` (PR 1). Idempotent and
# additive: anything the user pre-populated is preserved; we only
# ensure the canonical ``say`` schema is present.
+4 -5
View File
@@ -20,11 +20,10 @@
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# SmolVLA2 note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the chat tokenizer
# (`_flatten_say_tool_calls`) before `apply_chat_template`, so the LM
# head learns to emit exactly the marker the runtime parses back
# (`_split_plan_and_say`).
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
@@ -20,11 +20,10 @@
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# SmolVLA2 note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the chat tokenizer
# (`_flatten_say_tool_calls`) before `apply_chat_template`, so the LM
# head learns to emit exactly the marker the runtime parses back
# (`_split_plan_and_say`).
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
@@ -1,5 +1,4 @@
# subtasks_vqa — Hi-Robot blend, shared between SmolVLA2 (SmolVLM2
# backbone) and PI052 (PaliGemma backbone).
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
#
# Trains two things only: subtasks and VQA. Plan and memory are
# intentionally left out — keeps the prompt short and the training
@@ -10,9 +9,8 @@
# low_level_execution — flow loss with [images, subtask, state].
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# Each backbone's text tokenizer renders these messages differently
# (SmolVLA2 uses the chat template; PI052 concatenates as plain
# ``Role: content`` text), but the recipe spec is identical.
# PI052's text tokenizer renders these messages as plain
# ``Role: content`` text (PaliGemma is not chat-pretrained).
blend:
-2
View File
@@ -27,7 +27,6 @@ from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .utils import make_robot_action, prepare_observation_for_inference
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
@@ -52,7 +51,6 @@ __all__ = [
"SACConfig",
"SARMConfig",
"SmolVLAConfig",
"SmolVLA2Config",
"TDMPCConfig",
"VQBeTConfig",
"WallXConfig",
+2 -21
View File
@@ -144,10 +144,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "smolvla2":
from .smolvla2.modeling_smolvla2 import SmolVLA2Policy
return SmolVLA2Policy
elif name == "sarm":
from .sarm.modeling_sarm import SARMRewardModel
@@ -180,8 +176,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"smolvla", "reward_classifier", "wall_x".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
"pi052", "sac", "smolvla", "reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -212,10 +208,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "smolvla2":
from .smolvla2.configuration_smolvla2 import SmolVLA2Config
return SmolVLA2Config(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
@@ -424,17 +416,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif policy_cfg.type == "smolvla2":
# NOTE: SmolVLA2Config subclasses SmolVLAConfig, so this branch
# MUST come before the SmolVLAConfig isinstance check below
# (otherwise SmolVLA2 would silently pick up SmolVLA's processor).
from .smolvla2.processor_smolvla2 import make_smolvla2_pre_post_processors
processors = make_smolvla2_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SmolVLAConfig):
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
@@ -32,8 +32,8 @@ This is the dual-head co-training pattern from the paper:
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
inference into a text-prediction step followed by an action-prediction
step, which mirrors what ``SmolVLA2Runtime`` already does on a
SmolVLM2 backbone.
step, which the multi-rate ``PI052Runtime`` (in
``lerobot.policies.pi052.inference``) drives at separate rates.
"""
from dataclasses import dataclass
@@ -48,10 +48,9 @@ from ..pi05.configuration_pi05 import PI05Config
class PI052Config(PI05Config):
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head
config. Same recipe-driven training surface; the only difference is
which backbone the policy uses (PaliGemma here vs SmolVLM2 there).
The flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
Recipe-driven dual-head training: the flow head supervises actions,
the LM head supervises subtask / plan / memory / VQA text. The
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
"""
# Recipe / language stack ---------------------------------------------
@@ -64,17 +63,17 @@ class PI052Config(PI05Config):
apply_chat_template: bool = False
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
chat template. So unlike SmolVLA2 we don't apply one. The recipe
renderer's output is concatenated as a plain prefix + assistant
suffix instead, mirroring how the π0.5 paper's high-level inference
samples text auto-regressively after the prefix."""
chat template, so we don't apply one. The recipe renderer's output
is concatenated as a plain prefix + assistant suffix instead,
mirroring how the π0.5 paper's high-level inference samples text
auto-regressively after the prefix."""
# Loss weights --------------------------------------------------------
# Paper §IV.D uses α=10 between the flow and text terms, assuming
# text is a rare auxiliary task. With the recipe stack the flow-only
# `low_level` branch fires on a large share of samples, so α=10
# swamps the LM head and collapses generation into degenerate
# repetition. We use the milder 5:1 split (matches SmolVLA2Config).
# repetition. We use the milder 5:1 split here.
text_loss_weight: float = 1.0
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
text training entirely (reverts to flow-only / π0.5 behaviour)."""
@@ -93,8 +92,8 @@ class PI052Config(PI05Config):
hierarchical inference."""
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
# Same regulariser surface as SmolVLA2: randomly drop non-target
# context messages so the LM head learns to handle missing /
# Randomly drop non-target context messages so the LM head learns
# to handle missing /
# stale plan / memory at inference. Defaults to 0.0 so behaviour
# is identical until explicitly enabled.
plan_dropout_prob: float = 0.0
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 inference / runtime orchestration.
"""PI052 inference / runtime orchestration.
Multi-rate runtime that mirrors the recipe-time training shape:
@@ -22,12 +22,12 @@ Multi-rate runtime that mirrors the recipe-time training shape:
ask_vqa_* AskVQAFwd (event: stdin question)
speech tool calls DispatchToolCalls (event: tool_call_pending)
The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and
calls ``run()``.
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
``run()``.
"""
from .repl import StdinReader
from .runtime import SmolVLA2Runtime
from .runtime import PI052Runtime
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
from .steps import (
AskVQAFwd,
@@ -44,7 +44,7 @@ from .ui import make_state_panel, print_robot_lines, print_user_line
__all__ = [
# runtime
"SmolVLA2Runtime",
"PI052Runtime",
"StdinReader",
# state helpers
"initial_runtime_state",
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stdin REPL event collector for the SmolVLA2 runtime.
"""Stdin REPL event collector for the PI052 runtime.
Reads non-blocking stdin lines, classifies each one heuristically:
@@ -23,7 +23,7 @@ Reads non-blocking stdin lines, classifies each one heuristically:
Plugged into the runtime via ``event_collector=StdinReader().poll``.
Note: the shipped CLI (``lerobot-smolvla2-runtime``) drives stdin
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
directly in its REPL / autonomous loops and does *not* wire this
collector; it's kept as the documented embedding hook and for tests.
"""
@@ -92,7 +92,7 @@ class StdinReader:
if not state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
state["task"] = task
print(f"[smolvla2] Task: {task}", flush=True)
print(f"[pi052] Task: {task}", flush=True)
self._seen_first_line = True
return
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 runtime loop.
"""PI052 runtime loop.
Threads the multi-rate inference pipeline together with a stdin REPL
event collector, drives ticks through :class:`TickClock`, and prints
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
@dataclass
class SmolVLA2Runtime:
class PI052Runtime:
"""Compose the inference pipeline and drive it tick-by-tick."""
policy: Any
@@ -195,11 +195,11 @@ class SmolVLA2Runtime:
def _flush_logs(self) -> None:
for line in self.state.get("log_lines") or []:
print(f"[smolvla2] {line}", flush=True)
print(f"[pi052] {line}", flush=True)
def _on_shutdown(self) -> None:
# Drain any queued action chunks safely.
queue = self.state.get("action_queue")
if isinstance(queue, deque):
queue.clear()
print("[smolvla2] runtime stopped", flush=True)
print("[pi052] runtime stopped", flush=True)
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference steps for the SmolVLA2 multi-rate runtime.
"""Inference steps for the PI052 multi-rate runtime.
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
the runtime applies them in order each tick. When a step's trigger
@@ -98,7 +98,7 @@ class LowLevelForward(InferenceStep):
if not state.get("task"):
return None
# SmolVLA produces *action chunks* (typically 50 steps via
# PI052 produces *action chunks* (typically 50 steps via
# flow-matching). Every step gets dispatched to the robot;
# popping one per dispatch tick is essentially free. Only
# generate a new chunk once the previous one has fully
@@ -256,34 +256,11 @@ def _build_text_batch(
) -> dict[str, Any]:
"""Tokenize chat messages into the batch ``select_message`` expects.
Dispatches on the policy backbone so one runtime drives both:
* ``smolvla2`` (SmolVLM2) chat template via ``apply_chat_template``.
* ``pi052`` (PaliGemma) flat ``Role: content`` text, since
PaliGemma is not chat-pretrained (mirrors ``PI052TextTokenizerStep``).
"""
if getattr(getattr(policy, "config", None), "type", "") == "pi052":
return _build_text_batch_pi052(
policy, prompt_messages, add_generation_prompt=add_generation_prompt
)
return _build_text_batch_chat(
policy, prompt_messages, add_generation_prompt=add_generation_prompt
)
def _build_text_batch_pi052(
policy: Any,
prompt_messages: list[dict[str, Any]],
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""PI052 text batch — flat ``User: … \\nAssistant: …`` prompt.
PaliGemma ships no chat template, so PI052 trains on the plain
role-prefixed concatenation built by ``PI052TextTokenizerStep``.
Reuses that exact formatter so the inference prefix matches
training. ``add_generation_prompt`` appends the bare ``Assistant: ``
header the LM head continues from.
PI052's backbone (PaliGemma) ships no chat template, so we train on
a plain role-prefixed concatenation built by
``PI052TextTokenizerStep``. We reuse that exact formatter so the
inference prefix matches training; ``add_generation_prompt`` appends
the bare ``Assistant: `` header the LM head continues from.
"""
import torch # noqa: PLC0415
from transformers import AutoTokenizer # noqa: PLC0415
@@ -315,99 +292,6 @@ def _build_text_batch_pi052(
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
attn = attn.bool()
device = getattr(getattr(policy, "config", None), "device", None)
if device is not None:
try:
ids = ids.to(device)
if attn is not None and hasattr(attn, "to"):
attn = attn.to(device)
except Exception as exc: # noqa: BLE001
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
def _build_text_batch_chat(
policy: Any,
prompt_messages: list[dict[str, Any]],
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""SmolVLA2 (SmolVLM2) text batch — chat-template tokenization.
Reuses ``_strip_lerobot_blocks`` so the inference prompt shape
matches the training-time chat tokenizer step exactly.
"""
from transformers import AutoTokenizer # noqa: PLC0415
tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token = tokenizer.eos_token
# Reuse the *exact* normaliser that the training-time chat
# tokenizer step uses (``_strip_lerobot_blocks``). It handles all
# the cases the SmolVLM chat template expects:
# * ``content: list[block]`` → keep text blocks, drop images
# * ``content: None`` → ``[{type: text, text: ""}]``
# * ``content: str`` / anything else → ``[{type: text, text: str(content)}]``
# Doing it any other way creates a training/inference mismatch in
# exactly the prompt shape the model was supervised on. Also
# strips ``stream`` / ``target`` recipe metadata.
from lerobot.policies.smolvla2.chat_processor_smolvla2 import ( # noqa: PLC0415
_strip_lerobot_blocks,
)
text_messages = [_strip_lerobot_blocks(m) for m in prompt_messages]
encoded = tokenizer.apply_chat_template(
text_messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
)
# ``apply_chat_template`` can return any of:
# - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers),
# - a list[int] / list[list[int]] (when ``return_tensors`` is ignored),
# - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask``
# (newer transformers, especially via processor.apply_chat_template
# forwarding through here).
# Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's
# attention mask when available so we don't have to re-derive it
# from ``pad_token_id`` (which can be ``None`` for SmolVLM).
attn: Any = None
if hasattr(encoded, "input_ids"):
ids = encoded.input_ids
attn = getattr(encoded, "attention_mask", None)
elif isinstance(encoded, dict) and "input_ids" in encoded:
ids = encoded["input_ids"]
attn = encoded.get("attention_mask")
else:
ids = encoded
if isinstance(ids, list):
if ids and isinstance(ids[0], list):
ids = ids[0]
import torch # noqa: PLC0415
ids = torch.tensor(ids, dtype=torch.long)
if hasattr(ids, "ndim") and ids.ndim == 1:
ids = ids.unsqueeze(0)
if attn is None and tokenizer.pad_token_id is not None:
attn = ids != tokenizer.pad_token_id
elif isinstance(attn, list):
import torch # noqa: PLC0415
attn = torch.tensor(attn, dtype=torch.long)
if attn.ndim == 1:
attn = attn.unsqueeze(0)
# SmolVLA's ``eager_attention_forward`` does
# ``torch.where(attention_mask[..., None, :, :], ...)`` which
# requires a *bool* condition tensor; ``BatchEncoding``'s
# attention_mask is typically Long (0/1). Cast so the prefix
# forward doesn't blow up with ``where expected condition to be a
# boolean tensor, but got a tensor with dtype Long``.
if attn is not None and hasattr(attn, "dtype"):
import torch as _torch # noqa: PLC0415
if attn.dtype != _torch.bool:
attn = attn.bool()
# Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently.
@@ -418,7 +302,7 @@ def _build_text_batch_chat(
if attn is not None and hasattr(attn, "to"):
attn = attn.to(device)
except Exception as exc: # noqa: BLE001
logger.debug("could not move lang tokens to %s: %s", device, exc)
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
@@ -1022,12 +906,6 @@ def _generate_with_policy(
"temperature": temperature,
"top_p": top_p,
}
# Only pass ``suppress_loc_tokens`` to backbones that accept it
# (pi052). SmolVLA2's ``select_message`` does not, so we omit
# the kwarg there to avoid breaking the shared runtime.
import inspect # noqa: PLC0415
if "suppress_loc_tokens" in inspect.signature(policy.select_message).parameters:
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
return policy.select_message(batch, **kwargs)
except Exception as exc: # noqa: BLE001
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trigger primitives for SmolVLA2's multi-rate inference runtime.
"""Trigger primitives for PI052's multi-rate inference runtime.
Mirrors the plan's Section "Runtime orchestration": each
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rich-based REPL layout for the SmolVLA2 runtime.
"""Rich-based REPL layout for the PI052 runtime.
Two-zone terminal layout:
@@ -97,7 +97,7 @@ def make_state_panel(state: dict[str, Any]) -> Any:
)
return Panel(
table,
title=f"[bold]SmolVLA2 state[/] · mode: {mode_tag}",
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
border_style="cyan",
)
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interactive VQA for the SmolVLA2 runtime.
"""Interactive VQA for the PI052 runtime.
In ``/vlm`` mode a typed line is treated as a VQA question. This module
runs the full interactive flow:
@@ -379,7 +379,7 @@ def handle_vqa_query(
# Feed the FULL observation (every camera + state) to the VLM. The
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
# stripped before tokenization — the actual frames reach the model
# via SmolVLA's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
# consumes *all* ``config.image_features`` regardless of which
# camera the sub-recipe was tagged for. So the model always sees
# every camera; the operator never has to name one to ask.
+9 -10
View File
@@ -29,10 +29,10 @@ A thin subclass of :class:`PI05Policy` that:
with α controllable via ``config.flow_loss_weight``.
The same multi-rate runtime that drives ``SmolVLA2Runtime`` (see
``lerobot.policies.smolvla2.inference``) can drive this policy too —
both expose ``predict_action_chunk`` for the action expert and
``select_message`` for the LM head.
The multi-rate inference runtime in ``lerobot.policies.pi052.inference``
(driven by the ``lerobot-pi052-runtime`` CLI) sits on top of this:
``predict_action_chunk`` for the action expert and ``select_message``
for the LM head.
"""
from __future__ import annotations
@@ -380,9 +380,9 @@ class PI052Policy(PI05Policy):
for p in backbone.lm_head.parameters():
p.requires_grad_(True)
# The text model's final norm and last transformer block —
# mirror SmolVLA2's logic, which finds these dynamically by
# the trainable=False parameters that point at the head's
# neighbourhood.
# find them dynamically by walking up from the LM head so we
# don't hard-code module names that may drift across transformers
# versions.
text_model = getattr(backbone, "model", None)
text_model = getattr(text_model, "language_model", text_model)
if text_model is None:
@@ -934,9 +934,8 @@ class PI052Policy(PI05Policy):
) -> str:
"""Generate text continuation from a multimodal prefix.
Mirrors ``SmolVLA2Policy.select_message`` so the same
:class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime`
can drive π0.5 v2 unchanged.
Consumed by :class:`lerobot.policies.pi052.inference.PI052Runtime`
for the high-level / VQA / memory-update text streams.
``suppress_loc_tokens`` masks PaliGemma's reserved ``<locDDDD>``
ids ([256000, 257024)) to ``-inf`` before sampling. PaliGemma's
@@ -182,8 +182,7 @@ def _load_recipe(path_str: str) -> TrainingRecipe:
"""Resolve ``path_str`` to a ``TrainingRecipe``.
Accepts an absolute path or a path relative to
``src/lerobot/configs/`` (same lookup rules as
``make_smolvla2_pre_post_processors``).
``src/lerobot/configs/``.
"""
p = Path(path_str)
if not p.is_absolute() and not p.exists():
@@ -14,7 +14,7 @@
"""π0.5 v2 text-tokenisation step.
PaliGemma is *not* chat-pretrained, so unlike SmolVLA2 we can't lean on
PaliGemma is *not* chat-pretrained, so we can't lean on
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
messages as plain text with simple ``User: ... Assistant: ...`` role
delimiters — matching the prompt format π0.5 uses in the paper
@@ -30,8 +30,7 @@ Outputs:
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
those positions via the PaliGemma ``lm_head``.
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
target messages has ``message_streams[i] == "low_level"``. Same
semantics as the SmolVLA2 step.
target messages has ``message_streams[i] == "low_level"``.
"""
from __future__ import annotations
@@ -54,11 +53,9 @@ logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Debug helper — see ``chat_processor_smolvla2._dump_recipe_sample`` for the
# matching SmolVLA2 implementation. Behaviour: when
# ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N samples processed (on
# rank 0) are pretty-printed with ``[TGT]...[/TGT]`` markers over the spans
# the LM head will be supervised on.
# Debug helper — when ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N
# samples processed (on rank 0) are pretty-printed with ``[TGT]...[/TGT]``
# markers over the spans the LM head will be supervised on.
# ---------------------------------------------------------------------------
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
@@ -246,8 +243,8 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
# values exceeding the camera's pixel dimensions — they're not pixels.)
# Converting to ``<loc>`` is therefore camera-resolution-independent:
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
# not in the dataset — so the dataset stays backbone-agnostic (SmolVLA2
# keeps the JSON).
# not in the dataset — so the dataset keeps the raw JSON and stays
# backbone-agnostic.
# ---------------------------------------------------------------------------
# The 0–1000 scale Qwen2.5-VL emits for grounding coordinates.
@@ -424,8 +421,8 @@ def _format_messages(
class PI052TextTokenizerStep(ProcessorStep):
"""Render messages → token ids + label mask + predict_actions flag.
π0.5 analogue of ``SmolVLA2ChatTokenizerStep``. No chat template;
concatenates messages as ``User: ... \\nAssistant: ...`` text.
No chat template; concatenates messages as
``User: ... \\nAssistant: ...`` text.
"""
tokenizer_name: str = "google/paligemma-3b-pt-224"
@@ -613,8 +610,7 @@ class PI052TextTokenizerStep(ProcessorStep):
continue
labels[token_pos] = input_ids[token_pos]
# Scan ALL message streams (not just targets) — see
# ``chat_processor_smolvla2.py`` for rationale: the v2
# Scan ALL message streams (not just targets): the
# ``low_level_execution`` recipe drops ``target: true`` on
# the assistant to avoid trivial copy-from-user text-CE; the
# flow loss still needs to fire, gated by ``stream: low_level``.
@@ -686,7 +682,7 @@ class PI052TextTokenizerStep(ProcessorStep):
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
"""Heuristic content-prefix classifier — mirrors SmolVLA2's."""
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
content = message.get("content")
if isinstance(content, list):
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
-38
View File
@@ -1,38 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 — SmolVLA with the SmolVLM language head re-enabled.
SmolVLA strips the LM head from the SmolVLM backbone because it only does
flow-matching action prediction. SmolVLA2 keeps the LM head so the same
model can train on the full Hi Robot / MEM / ECoT message blend defined in
the steerable annotation plan (PR1 + PR2):
* action-only sub-recipes (e.g. ``low_level_execution``) → flow loss
* text-only sub-recipes (e.g. ``memory_update``, ``ask_vqa``,
``user_interjection_response``, ``high_level_subtask``) → CE loss on
``lm_head`` over the recipe's target message tokens
* mixed sub-recipes → both losses summed (weighted)
The ``predict_actions`` toggle follows the Pi0.5 convention from Section
I.7 of the plan: ``True`` if any ``low_level`` target is present in the
sample, else ``False``.
This package is a thin subclass of ``lerobot.policies.smolvla`` so most of
the model code stays in one place — only the dual-loss path and the
chat-template processor live here.
"""
from .configuration_smolvla2 import SmolVLA2Config
__all__ = ["SmolVLA2Config"]
@@ -1,649 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2's chat-template tokenization step.
Replaces SmolVLA's plain ``TokenizerProcessorStep`` for SmolVLA2 when a
``recipe_path`` is set. Reads the rendered messages produced by
``RenderMessagesStep`` (PR 1) and produces:
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` —
the chat-templated prompt tokenized by SmolVLM's tokenizer, with
``tools=meta.tools`` (PR 1's catalog).
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
the positions belonging to messages whose index is in
``target_message_indices``. The next commit's modeling forward path
applies cross-entropy on those positions via the SmolVLM ``lm_head``.
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
target messages has ``message_streams[i] == "low_level"``. The
modeling forward uses this to gate the flow head.
Image / video content blocks in the rendered messages are dropped
before tokenization — the chat template only handles text, and SmolVLA
already passes camera tensors out-of-band via the standard
``OBS_IMAGES_*`` features. This keeps the prefix layout unchanged
(``embed_prefix`` puts image embeddings before language embeddings,
matching the chat-template-stripped text order).
"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from typing import Any
import torch
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Debug helper: dump the first N rendered samples to stdout so you can sanity-
# check what the model actually sees before kicking off a long training run.
#
# LEROBOT_DUMP_RECIPE_SAMPLES=5 lerobot-train ...
#
# Prints the recipe-rendered messages, the chat-templated text (decoded back
# from token ids), and inline ``[TGT]...[/TGT]`` markers showing which spans
# are supervised by text-CE. Stops after N total dumps to keep training logs
# tractable. Rank-0 only when accelerate sets ``RANK``.
# ---------------------------------------------------------------------------
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
_DUMPED_SO_FAR = 0
def _is_dump_rank() -> bool:
rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
try:
return int(rank) == 0
except ValueError:
return True
def _dump_recipe_sample(
    *,
    messages: list[dict[str, Any]],
    token_ids: list[int],
    labels: list[int],
    predict_actions: bool,
    tokenizer: Any,
) -> None:
    """Print one rendered sample for debugging, within the global dump budget.

    Does nothing once ``_DUMPED_SO_FAR`` has reached ``_DUMP_BUDGET`` or
    when ``_is_dump_rank()`` is false. Otherwise it bumps the counter and
    prints the messages, the decoded token ids, and a second decoding in
    which every run of supervised positions (``labels != -100``) is
    wrapped in ``[TGT]...[/TGT]``.
    """
    global _DUMPED_SO_FAR
    if _DUMPED_SO_FAR >= _DUMP_BUDGET:
        return
    if not _is_dump_rank():
        return
    _DUMPED_SO_FAR += 1

    decoded = tokenizer.decode(token_ids, skip_special_tokens=False)

    # Walk the labels as alternating runs of ignored / supervised positions
    # and decode each run separately so the markers land on run boundaries.
    segments: list[str] = []
    total = len(labels)
    run_start = 0
    while run_start < total:
        supervised = labels[run_start] != -100
        run_end = run_start + 1
        while run_end < total and (labels[run_end] != -100) == supervised:
            run_end += 1
        span_text = tokenizer.decode(token_ids[run_start:run_end], skip_special_tokens=False)
        segments.append(f"[TGT]{span_text}[/TGT]" if supervised else span_text)
        run_start = run_end
    annotated = "".join(segments)
    n_tgt = len([label for label in labels if label != -100])

    print(
        f"\n========== RECIPE SAMPLE DUMP ({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
        flush=True,
    )
    print(f" predict_actions: {predict_actions}", flush=True)
    print(f" rendered messages ({len(messages)}):", flush=True)
    for entry in messages:
        stream = entry.get("stream")
        target = entry.get("target")
        role = entry.get("role")
        content = entry.get("content")
        print(f" - role={role} stream={stream} target={target}", flush=True)
        print(f" content: {content!r}", flush=True)
    print(f" token count: {len(token_ids)} (target tokens: {n_tgt})", flush=True)
    print(f" decoded (raw):\n {decoded}", flush=True)
    print(f" decoded (with target markers):\n {annotated}", flush=True)
    print("==============================================\n", flush=True)
@dataclass
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
class SmolVLA2ChatTokenizerStep(ProcessorStep):
    """Render messages → token ids + label mask + predict_actions flag.

    Reads ``messages`` / ``message_streams`` / ``target_message_indices``
    (and optionally ``index``) from the transition's complementary data,
    tokenizes the messages through the tokenizer's chat template, and
    writes:

    * ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` on the
      observation,
    * ``text_labels`` (``-100`` outside the target-message spans) and
      ``predict_actions`` on the complementary data.

    Handles both a single message list and a batch (list of message
    lists). Transitions without ``messages`` are returned untouched.
    """

    # Name passed to ``AutoTokenizer.from_pretrained`` on first use.
    tokenizer_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
    # Hard cap on the padded sequence length; longer samples are truncated.
    max_length: int = 2048
    # ``"max_length"`` pads every sample to ``max_length``; any other value
    # pads to the longest sample in the call (still capped at ``max_length``).
    padding: str = "longest"
    # NOTE(review): not read by the methods of this class — ``__call__``
    # always appends padding on the right. Confirm whether anything else
    # consumes this field.
    padding_side: str = "right"
    # Tools catalog forwarded as ``tools=`` to ``apply_chat_template``;
    # ``None`` is replaced by ``[]`` in ``__post_init__``.
    tools: list[dict[str, Any]] | None = None
    # --- Per-component prompt dropout (Pi0.7 §V.E, plan follow-up
    # ``feat/pi05-prompt-dropout``). At training, drop non-target
    # messages whose content was substituted from the named recipe
    # binding with the given probability. Forces the model to handle
    # missing context at inference — directly attacks the memorisation
    # collapse where ``current_subtask=""`` puts the prompt OOD. All
    # default to 0.0 (no dropout) so behaviour is identical until
    # explicitly opted in via the training config.
    plan_dropout_prob: float = 0.0
    memory_dropout_prob: float = 0.0
    subtask_dropout_prob: float = 0.0
    interjection_dropout_prob: float = 0.0
    # Optional seed for the per-sample RNG. ``None`` ⇒ use
    # ``sample_idx`` derived from the transition (when present), so
    # dropout is reproducible across runs but varies per sample.
    dropout_seed: int | None = None

    def __post_init__(self) -> None:
        """Initialise the lazy tokenizer slot and default ``tools`` to ``[]``."""
        # Lazy: don't load the tokenizer until the step actually runs,
        # so unit tests that import the module without transformers
        # installed still pass.
        self._tokenizer: Any = None
        if self.tools is None:
            # Default: no tools rendered into the system prompt. The
            # ``say()`` tool was only used by the now-removed
            # ``user_interjection_response`` recipe; including its
            # schema on every sample adds a long system message to
            # the action expert's prefix and creates a train/inference
            # mismatch (the inference low-level loop doesn't pass
            # tools=, so the chat template doesn't render them).
            # Users who actually need tools can set them via
            # ``with_tools(meta.tools)``.
            self.tools = []

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def with_tools(self, tools: list[dict[str, Any]]) -> "SmolVLA2ChatTokenizerStep":
        """Override the tools catalog rendered into the system prompt.

        Stores a shallow copy of ``tools`` and returns ``self`` so calls
        can be chained.
        """
        self.tools = list(tools)
        return self

    def __call__(self, transition: EnvTransition) -> EnvTransition | None:
        """Tokenize the rendered messages carried by ``transition``.

        Returns ``transition`` unchanged when its complementary data has no
        ``messages``. Otherwise returns a shallow copy in which:

        * the observation gains ``OBS_LANGUAGE_TOKENS`` (long) and
          ``OBS_LANGUAGE_ATTENTION_MASK`` (bool);
        * the complementary data gains ``text_labels`` (long, ``-100`` on
          unsupervised and padded positions) and ``predict_actions`` (bool);
        * ``messages``, ``message_streams``, ``target_message_indices`` and
          ``task`` are removed from the complementary data.

        For a non-batched input the leading batch dimension is squeezed
        away from all four tensors.
        """
        comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
        messages = comp.get("messages")
        if not messages:
            # No recipe rendering happened — nothing to do; downstream
            # falls back to whatever ``task`` is in the transition.
            return transition
        tokenizer = self._get_tokenizer()
        # Pull a sample_idx for the dropout RNG. ``index`` is the
        # canonical per-frame key on ``LeRobotDataset`` samples and
        # flows through into ``COMPLEMENTARY_DATA`` unchanged. When
        # absent (e.g. inference) we fall back to 0 which is harmless
        # because the dropout probs are also 0 at inference time.
        if _is_batched_messages(messages):
            indices_iter = _sample_indices(comp.get("index"), len(messages))
            # One (ids, labels, predict_actions) triple per sample. Missing
            # stream / target metadata defaults to an empty list per sample.
            encoded = [
                self._encode_messages(
                    tokenizer,
                    msg,
                    list(streams),
                    sorted(int(i) for i in tgt_indices),
                    sample_idx=int(s_idx) if s_idx is not None else None,
                )
                for msg, streams, tgt_indices, s_idx in zip(
                    messages,
                    comp.get("message_streams") or [[] for _ in messages],
                    comp.get("target_message_indices") or [[] for _ in messages],
                    indices_iter,
                    strict=False,
                )
            ]
        else:
            sample_idx = _sample_indices(comp.get("index"), 1)[0]
            encoded = [
                self._encode_messages(
                    tokenizer,
                    messages,
                    list(comp.get("message_streams") or []),
                    sorted(int(i) for i in (comp.get("target_message_indices") or [])),
                    sample_idx=sample_idx,
                )
            ]
        # Optional first-N-samples debug dump for sanity-checking what the
        # model actually sees. No-op unless ``LEROBOT_DUMP_RECIPE_SAMPLES``
        # is set; stops globally after the budget is exhausted.
        if _DUMP_BUDGET > 0:
            # Stream / target metadata live in parallel arrays in
            # COMPLEMENTARY_DATA, not on the message dicts themselves
            # (the recipe renderer keeps them separate so the chat
            # template doesn't choke on unknown keys). Zip them back
            # together for the dumper so each printed message shows
            # its actual stream + target flag.
            if _is_batched_messages(messages):
                msgs_iter = messages
                streams_iter = comp.get("message_streams") or [[] for _ in messages]
                targets_iter = comp.get("target_message_indices") or [[] for _ in messages]
            else:
                msgs_iter = [messages]
                streams_iter = [list(comp.get("message_streams") or [])]
                targets_iter = [list(comp.get("target_message_indices") or [])]
            for msg, streams, targets, (ids, labels, predict_action) in zip(
                msgs_iter, streams_iter, targets_iter, encoded, strict=False
            ):
                target_set = {int(i) for i in targets}
                annotated_msgs = []
                for i, m in enumerate(msg):
                    annotated_msgs.append(
                        {
                            **m,
                            "stream": streams[i] if i < len(streams) else None,
                            "target": True if i in target_set else None,
                        }
                    )
                _dump_recipe_sample(
                    messages=annotated_msgs,
                    token_ids=ids,
                    labels=labels,
                    predict_actions=predict_action,
                    tokenizer=tokenizer,
                )
        # Pad with the tokenizer's pad id, or 0 when it has none.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        target_length = self.max_length if self.padding == "max_length" else max(
            len(ids) for ids, _, _ in encoded
        )
        # Never exceed ``max_length``, even in "longest" mode.
        target_length = min(target_length, self.max_length)
        ids_batch = []
        attn_batch = []
        labels_batch = []
        predict_actions = []
        for ids, labels, predict_action in encoded:
            # Truncate on the right, then right-pad up to ``target_length``.
            ids = ids[:target_length]
            labels = labels[:target_length]
            attn = [1] * len(ids)
            if len(ids) < target_length:
                n_pad = target_length - len(ids)
                ids = ids + [pad_id] * n_pad
                # Padded positions are never supervised and never attended.
                labels = labels + [-100] * n_pad
                attn = attn + [0] * n_pad
            ids_batch.append(ids)
            attn_batch.append(attn)
            labels_batch.append(labels)
            predict_actions.append(predict_action)
        ids_t = torch.tensor(ids_batch, dtype=torch.long)
        attn_t = torch.tensor(attn_batch, dtype=torch.bool)
        labels_t = torch.tensor(labels_batch, dtype=torch.long)
        predict_actions_t = torch.tensor(predict_actions, dtype=torch.bool)
        if not _is_batched_messages(messages):
            # Single-sample input: drop the batch dimension added above.
            ids_t = ids_t.squeeze(0)
            attn_t = attn_t.squeeze(0)
            labels_t = labels_t.squeeze(0)
            predict_actions_t = predict_actions_t.squeeze(0)
        new_complementary = dict(comp)
        # Drop the per-recipe sidecar keys; everything downstream needs
        # is now in the tokenized form.
        new_complementary.pop("messages", None)
        new_complementary.pop("message_streams", None)
        new_complementary.pop("target_message_indices", None)
        # SmolVLA's pipeline expects ``OBS_LANGUAGE_TOKENS`` /
        # ``OBS_LANGUAGE_ATTENTION_MASK`` on the OBSERVATION key. Place
        # them there — and drop ``task`` so the upstream
        # ``TokenizerProcessorStep`` (which we replace) doesn't double-
        # tokenize.
        observation = dict(transition.get(TransitionKey.OBSERVATION) or {})
        observation[OBS_LANGUAGE_TOKENS] = ids_t
        observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
        new_complementary["text_labels"] = labels_t
        new_complementary["predict_actions"] = predict_actions_t
        new_complementary.pop("task", None)
        new_transition = dict(transition)
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary
        new_transition[TransitionKey.OBSERVATION] = observation
        return new_transition

    def _encode_messages(
        self,
        tokenizer: Any,
        messages: list[dict[str, Any]],
        message_streams: list[str | None],
        target_indices: list[int],
        sample_idx: int | None = None,
    ) -> tuple[list[int], list[int], bool]:
        """Tokenize one sample and build its label mask.

        Returns ``(token_ids, labels, predict_actions)``. ``labels`` has the
        same length as ``token_ids`` and is ``-100`` everywhere except on
        the span each target message occupies; that span is found by
        templating the conversation up to, and then through, the target
        message and comparing the two lengths. ``predict_actions`` is
        computed from ``message_streams`` as passed in (i.e. before prompt
        dropout).
        """
        # Apply per-component prompt dropout *before* tokenisation, so
        # the dropped messages don't contribute tokens or label-mask
        # positions at all. Re-maps ``target_indices`` to account for
        # removed messages.
        messages, target_indices = self._apply_prompt_dropout(
            messages, target_indices, sample_idx
        )
        # Flatten ``tool_calls`` into a textual ``<say>...</say>`` marker
        # *before* the chat template sees them, so the model is trained
        # to emit the same marker the inference parser
        # (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``.
        messages = [_flatten_say_tool_calls(m) for m in messages]
        text_messages = [_strip_lerobot_blocks(m) for m in messages]
        full_ids = tokenizer.apply_chat_template(
            text_messages,
            tools=self.tools,
            add_generation_prompt=False,
            tokenize=True,
            return_tensors=None,
        )
        full_ids = _as_token_ids(full_ids)
        labels = [-100] * len(full_ids)
        for tgt in target_indices:
            # Tokens of everything *before* the target message ...
            prefix_ids = tokenizer.apply_chat_template(
                text_messages[:tgt],
                tools=self.tools,
                add_generation_prompt=False,
                tokenize=True,
                return_tensors=None,
            )
            # ... and of everything up to and including it.
            full_through_target = tokenizer.apply_chat_template(
                text_messages[: tgt + 1],
                tools=self.tools,
                add_generation_prompt=False,
                tokenize=True,
                return_tensors=None,
            )
            prefix_ids = _as_token_ids(prefix_ids)
            full_through_target = _as_token_ids(full_through_target)
            # The target's span is [len(prefix), len(through-target)),
            # clamped to the length of the full sequence.
            start = len(prefix_ids)
            end = min(len(full_through_target), len(full_ids))
            for pos in range(start, end):
                labels[pos] = int(full_ids[pos])
        # ``predict_actions`` is True iff this sample's recipe declares
        # at least one ``low_level`` message — regardless of whether
        # it's a target. The ``low_level_execution`` recipe in v2 uses
        # ``stream: low_level`` on both user and assistant turns but
        # only renders the *user* subtask (no text-CE target on the
        # assistant) to avoid trivial "copy previous turn" supervision.
        # Scanning targets alone would miss this sample's action loss.
        predict_actions = any(s == "low_level" for s in message_streams)
        return [int(i) for i in full_ids], labels, predict_actions

    def _apply_prompt_dropout(
        self,
        messages: list[dict[str, Any]],
        target_indices: list[int],
        sample_idx: int | None,
    ) -> tuple[list[dict[str, Any]], list[int]]:
        """Probabilistically drop non-target context messages.

        Heuristic content sniffing — matches the prefix strings that
        ``subtask_mem_vqa_speech.yaml``'s recipes use when injecting plan /
        memory / subtask / interjection content. Anything else is
        kept unchanged. Target messages are never dropped (we still
        need their tokens for supervision).

        Returns ``(new_messages, new_target_indices)`` where the
        indices are re-mapped to point at the same target turns in
        the trimmed list. When every dropout probability is zero the
        inputs are returned as-is.
        """
        probs = {
            "plan": float(self.plan_dropout_prob or 0.0),
            "memory": float(self.memory_dropout_prob or 0.0),
            "subtask": float(self.subtask_dropout_prob or 0.0),
            "interjection": float(self.interjection_dropout_prob or 0.0),
        }
        if not any(p > 0.0 for p in probs.values()):
            return messages, target_indices
        # Deterministic per-sample RNG so dropout is reproducible
        # across runs (matters for debugging / repro) but varies
        # frame-to-frame.
        import random  # noqa: PLC0415

        # An explicit ``dropout_seed`` takes precedence over ``sample_idx``;
        # a missing ``sample_idx`` seeds with 0.
        seed_int = self.dropout_seed if self.dropout_seed is not None else (sample_idx or 0)
        rng = random.Random(int(seed_int) & 0xFFFFFFFF)
        target_set = set(target_indices)
        keep_flags: list[bool] = []
        for i, msg in enumerate(messages):
            if i in target_set:
                keep_flags.append(True)
                continue
            kind = _classify_message_for_dropout(msg)
            # The RNG is only advanced for messages with a recognised kind.
            if kind and rng.random() < probs.get(kind, 0.0):
                keep_flags.append(False)
            else:
                keep_flags.append(True)
        new_messages = [m for m, keep in zip(messages, keep_flags) if keep]
        # Re-map target_indices: each old index drops by the count of
        # falsy flags before it.
        new_target_indices: list[int] = []
        for old_idx in target_indices:
            dropped_before = sum(1 for k in keep_flags[:old_idx] if not k)
            new_target_indices.append(old_idx - dropped_before)
        return new_messages, sorted(new_target_indices)

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Pass-through; this step writes runtime tensors not features."""
        return features

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_tokenizer(self):  # noqa: ANN202
        """Return the cached tokenizer, loading it on first use.

        Loads ``tokenizer_name`` via ``AutoTokenizer.from_pretrained`` and,
        when the tokenizer has no pad token but does have an EOS token,
        reuses the EOS token as the pad token.

        Raises:
            ImportError: if ``transformers`` is not installed.
        """
        if self._tokenizer is not None:
            return self._tokenizer
        try:
            from transformers import AutoTokenizer  # noqa: PLC0415
        except ImportError as exc:  # pragma: no cover
            raise ImportError(
                "SmolVLA2ChatTokenizerStep requires transformers. "
                "`pip install lerobot[transformers-dep]`."
            ) from exc
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None:
            self._tokenizer.pad_token = self._tokenizer.eos_token
        return self._tokenizer
def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
"""Remove LeRobot-specific multimodal blocks from ``message`` content.
The recipe DSL allows authors to write multimodal content like
``{"type": "image", "feature": "observation.images.top"}``. SmolVLM's
tokenizer doesn't know that ``feature`` key (it expects ``url`` or
``path``). The actual image tensor flows through SmolVLA's
``OBS_IMAGES_*`` channels separately; the chat template only needs
the text. So we strip non-text blocks before tokenizing.
"""
new = dict(message)
content = new.get("content")
if isinstance(content, list):
text_parts: list[dict[str, Any]] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text":
text_parts.append({"type": "text", "text": str(block.get("text", ""))})
new["content"] = text_parts or [{"type": "text", "text": ""}]
elif content is None:
new["content"] = [{"type": "text", "text": ""}]
else:
new["content"] = [{"type": "text", "text": str(content)}]
if "tool_calls" in new and not new["tool_calls"]:
# Drop empty tool_calls — some templates render them as a
# spurious empty marker.
new.pop("tool_calls")
# ``stream`` and ``target`` were recipe metadata; templates don't
# know them and may warn or crash.
new.pop("stream", None)
new.pop("target", None)
return new
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to plain text."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
t = block.get("text")
if isinstance(t, str):
parts.append(t)
return "\n".join(parts)
return ""
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
"""Serialize assistant ``say`` tool calls into a textual ``<say>...</say>``
marker inside the message content (Pi 0.5-style flat tool-call
serialization).
SmolVLM's chat template would otherwise render ``tool_calls`` as a
structured JSON ``<tool_call>`` block, so the LM head learns to emit
JSON — but the inference parser ``_split_plan_and_say`` looks for a
``<say>...</say>`` marker (``_SAY_RE``). Rewriting the call into the
content text *before* ``apply_chat_template`` aligns the two: the
template only ever tokenizes plain text, and the supervised target
span trains the model to produce the exact marker the runtime reads.
Messages without ``say`` tool calls are returned unchanged.
"""
tool_calls = message.get("tool_calls")
if not tool_calls:
return message
say_texts: list[str] = []
for call in tool_calls:
if not isinstance(call, dict):
continue
fn = call.get("function") or {}
if fn.get("name") != "say":
continue
args = fn.get("arguments")
if isinstance(args, str):
try:
import json # noqa: PLC0415
args = json.loads(args)
except (ValueError, TypeError):
args = {}
text = args.get("text", "") if isinstance(args, dict) else ""
if text:
say_texts.append(str(text))
if not say_texts:
# No ``say`` calls (or empty text) — drop the structured calls so
# the template doesn't render a stray JSON block, but leave the
# content alone.
new = dict(message)
new.pop("tool_calls", None)
return new
new = dict(message)
base = _content_to_text(new.get("content")).strip()
marker = "".join(f"<say>{t}</say>" for t in say_texts)
new["content"] = f"{base}\n{marker}" if base else marker
new.pop("tool_calls", None)
return new
def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
if value is None:
return [None] * batch_size
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return [int(value.item())] * batch_size
values = value.reshape(-1).tolist()
return [int(v) for v in values[:batch_size]]
if isinstance(value, (list, tuple)):
if len(value) == 1:
return _sample_indices(value[0], batch_size)
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
return [int(value)] * batch_size
def _classify_message_for_dropout(message: dict[str, Any]) -> str | None:
"""Best-effort classification of which recipe binding contributed
to this message, used for per-component dropout.
The canonical recipe authors plan/memory/subtask injections with
distinctive prefix strings in the rendered content. Matching on
those prefixes is brittle if a future recipe author uses
different wording — but it's also localised to one place and
only affects the dropout fraction (never the actual semantics).
Returns ``None`` for messages we don't recognise; those are
always kept.
"""
content = message.get("content")
if isinstance(content, list):
text_parts: list[str] = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
t = block.get("text")
if isinstance(t, str):
text_parts.append(t)
content = "\n".join(text_parts)
if not isinstance(content, str):
return None
head = content.lstrip().lower()
if head.startswith("plan:") or head.startswith("previous plan"):
return "plan"
if head.startswith("memory:") or head.startswith("previous memory"):
return "memory"
if head.startswith("current subtask") or head.startswith("completed subtask"):
return "subtask"
return None
def _as_token_ids(value: Any) -> list[int]:
if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()):
value = value["input_ids"]
if hasattr(value, "tolist"):
value = value.tolist()
if isinstance(value, list) and value and isinstance(value[0], list):
value = value[0]
return [int(i) for i in value]
# Public aliases for the two private message-rewriting helpers, re-exported
# so tests and introspection code can reference them without the leading
# underscore. They are plain name bindings to the same function objects.
strip_lerobot_blocks = _strip_lerobot_blocks
flatten_say_tool_calls = _flatten_say_tool_calls
@@ -1,163 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from lerobot.configs import PreTrainedConfig
from ..smolvla.configuration_smolvla import SmolVLAConfig
@PreTrainedConfig.register_subclass("smolvla2")
@dataclass
class SmolVLA2Config(SmolVLAConfig):
    """SmolVLA2 — SmolVLA with the underlying SmolVLM language head re-enabled.

    SmolVLA strips the LM head from the SmolVLM backbone because it only
    needs flow-matching action prediction. SmolVLA2 keeps the LM head so the
    same model can train on:

    * **action-only sub-recipes** (e.g. ``low_level_execution``) — flow loss
      on the action expert, same as SmolVLA. ``predict_actions=True``.
    * **text-only sub-recipes** (e.g. ``memory_update`` / ``ask_vqa`` /
      ``user_interjection_response`` / ``high_level_subtask``) — cross-
      entropy loss on the LM head over the recipe's target message tokens.
      Skips the flow head entirely. ``predict_actions=False``.
    * **mixed sub-recipes** — both heads run, losses summed (weighted).

    The split is controlled by ``predict_actions = bool(targets_by_stream
    .get("low_level"))`` per the Pi0.5 convention in the steerable
    annotation plan (Section I.7), implemented inside the processor /
    forward path. Recipes drive it via ``stream`` + ``target`` metadata.

    Compared to ``SmolVLAConfig`` this adds:

    - ``recipe_path``: path to a ``TrainingRecipe`` YAML (loaded by the
      train script). When ``None``, SmolVLA2 falls back to the SmolVLA
      task-only path so unannotated datasets still work.
    - ``text_loss_weight`` / ``flow_loss_weight``: relative weights when
      both losses are active in a single sample.
    - ``unfreeze_lm_head``: must be ``True`` for the text head to learn —
      SmolVLA freezes ``lm_head`` to "avoid unused params issues" and we
      need to undo that for SmolVLA2.
    - ``plan_dropout_prob`` / ``memory_dropout_prob`` /
      ``subtask_dropout_prob``: per-component prompt dropout
      probabilities, validated to lie in ``[0, 1]``.
    - ``train_expert_only`` is forced to ``False`` in ``__post_init__``
      whenever text training is active (``text_loss_weight > 0`` and
      ``unfreeze_lm_head``), since the VLM body then also participates
      in text-target gradients.

    Raises:
        ValueError: from ``__post_init__`` when any ``*_dropout_prob``
            falls outside ``[0, 1]``.
    """

    # Recipe / language stack ---------------------------------------------
    recipe_path: str | None = "recipes/subtasks_vqa.yaml"
    """Path (absolute or relative to ``src/lerobot/configs/``) to a
    ``TrainingRecipe`` YAML. The default points at the canonical Hi Robot
    blend shipped alongside SmolVLA2. Set to ``None`` to disable recipe
    rendering and fall back to SmolVLA's single-task prompt path
    (unannotated datasets keep working that way)."""

    apply_chat_template: bool = True
    """Apply the SmolVLM tokenizer's chat template to the rendered messages
    before tokenizing. SmolVLM's backbone is chat-pretrained, so this
    matches its training distribution."""

    # Loss weights --------------------------------------------------------
    # Pi 0.5 paper §IV.D (Eq. 1) sets α = 10 between the text-CE term
    # and the flow-MSE term: L = H(text) + α * ‖ω - a - f_θ‖². The
    # rationale is that actions are the primary output and the flow
    # head should dominate the gradient signal; text is supervised as
    # an auxiliary task and its CE scale (~0.5-2.0 in nats) tends to
    # be larger than the flow MSE scale (~0.1-1.0), so without
    # up-weighting the action head gets starved. We use a milder
    # split (5:1) than the paper's α=10: ~40% of the blend is the
    # flow-only ``low_level`` recipe, so the flow term already fires
    # often, and α=10 starved the text head into degenerate decoding.
    text_loss_weight: float = 1.0
    """Weight on the LM-head cross-entropy term. Set to ``0`` to disable
    text training entirely (reverts to flow-only / SmolVLA behaviour)."""

    flow_loss_weight: float = 5.0
    """Weight on the action-expert flow-matching term. Default 5.0 — a
    milder split than the Pi 0.5 paper's α=10 (§IV.D), since the
    flow-only ``low_level`` recipe already gives the action expert
    frequent gradient. Set lower if the text head is underfitting
    relative to the action expert; set higher if the action expert is
    degrading because text loss dominates."""

    # Optimizer -----------------------------------------------------------
    optimizer_lr: float = 2.5e-5
    """Peak learning rate. Overrides ``SmolVLAConfig``'s ``1e-4``.
    SmolVLA can afford ``1e-4`` because it *freezes* the language head —
    only the from-scratch action expert sees that LR. SmolVLA2 unfreezes
    ``lm_head`` + the last text layer and fine-tunes the **pretrained**
    SmolVLM2 language weights, and ``1e-4`` is too aggressive for a
    pretrained LM: it destabilises the language representations and
    collapses generation into degenerate repetition. ``2.5e-5`` matches
    pi05's peak LR (openpi ``CosineDecaySchedule``), the comparable
    text-co-trained policy. The action expert trains slightly slower at
    this LR, so budget more steps."""

    # Backbone training ---------------------------------------------------
    unfreeze_lm_head: bool = True
    """Whether to unfreeze the SmolVLM ``lm_head`` (and the immediately
    preceding norm + last text-model layer that SmolVLA freezes). Must be
    ``True`` for the text head to learn. Setting this to ``False``
    effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
    which is occasionally useful for ablations."""

    load_vlm_weights: bool = True
    """Load the pretrained SmolVLM2 backbone weights (vision tower +
    language model + ``lm_head``) instead of random-initialising them.
    ``SmolVLAConfig`` defaults this to ``False`` because the original
    SmolVLA pre-training run trained the VLM body itself. For SmolVLA2
    that default is a footgun: the text head **is** the SmolVLM2
    ``lm_head``, and the high-level subtask supervision is hopeless if
    it starts from a random language model — it can only memorise.
    SmolVLA2 therefore defaults this to ``True`` so every run fine-tunes
    from the pretrained ``vlm_model_name`` checkpoint
    (``HuggingFaceTB/SmolVLM2-500M-Video-Instruct``).
    Note this loads the *VLM backbone* pretrained; the action expert
    still trains from scratch on the robot data (standard SmolVLA
    fine-tuning). To also start the action expert from pretrained
    weights, fine-tune from a full ``lerobot/smolvla_base`` checkpoint
    via ``--policy.path``."""

    # Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
    # At training, randomly drop non-target context messages whose
    # content was substituted from the named recipe binding. Forces
    # the model to handle missing context — directly attacks the
    # memorisation collapse where a stale or missing plan/memory at
    # inference puts the prompt out-of-distribution and the LM head
    # falls back to dominant-mode fragments. All default to 0.0 so
    # behaviour is identical until explicitly enabled.
    plan_dropout_prob: float = 0.0
    """Drop messages whose content starts with ``Plan:`` or ``Previous plan``
    with this probability per sample."""

    memory_dropout_prob: float = 0.0
    """Drop messages whose content starts with ``Memory:`` or ``Previous memory``
    with this probability per sample."""

    subtask_dropout_prob: float = 0.0
    """Drop messages whose content starts with ``Current subtask`` or
    ``Completed subtask`` with this probability per sample."""

    def __post_init__(self) -> None:
        """Validate SmolVLA2-specific fields and enable backbone training.

        Runs the parent ``__post_init__`` first, then rejects dropout
        probabilities outside ``[0, 1]``, then turns off
        ``train_expert_only`` when the text head is being trained.

        Raises:
            ValueError: if any ``*_dropout_prob`` is outside ``[0, 1]``
                (NaN is rejected as well).
        """
        super().__post_init__()

        # A probability outside [0, 1] is always a configuration mistake;
        # fail at construction time instead of silently always/never
        # dropping the component during training.
        for field_name in ("plan_dropout_prob", "memory_dropout_prob", "subtask_dropout_prob"):
            prob = getattr(self, field_name)
            if not 0.0 <= prob <= 1.0:
                raise ValueError(f"{field_name} must be in [0, 1], got {prob!r}.")

        # Backbone needs gradients flowing through its text path when the
        # LM head is producing supervised text, so the SmolVLA default
        # (`train_expert_only=True`) is overridden here.
        if self.text_loss_weight > 0 and self.unfreeze_lm_head:
            # NOTE: this assignment runs after dataclass init, so it also
            # overwrites an explicitly supplied `train_expert_only=True`.
            # The only ways to keep expert-only training are
            # `text_loss_weight=0` or `unfreeze_lm_head=False`.
            self.train_expert_only = False
@@ -1,692 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
Adds:
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
* a forward path that runs the flow head, the text head, or both,
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
produced by :class:`SmolVLA2ChatTokenizerStep`.
Per-sample routing — within one batch:
* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
loss (action chunk supervision).
* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
contribute to the LM-head cross-entropy.
Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
``text_labels`` nor ``predict_actions`` is in the batch — unannotated
datasets keep working unchanged.
"""
from __future__ import annotations
import math
from typing import Any
import torch
import torch.nn.functional as F
from torch import Tensor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
)
from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
from .configuration_smolvla2 import SmolVLA2Config
def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, int]:
"""Find ``[lang_start, lang_end)`` inside the SmolVLA prefix.
``embed_prefix`` lays out the prefix as
``[image_blocks..., lang, state, padding]`` with the att-mask
convention ``[0]*image, [0]*lang, [1]*state, [0]*padding`` (see
``modeling_smolvla.SmolVLAModel.embed_prefix``). State is exactly
one token, and it's the *only* position with ``att_mask == 1``,
so we use the first ``1`` to anchor lang_end. Computing it this
way is robust to (a) state being projected to one embedding token
regardless of its raw feature dim, and (b) the trailing padding
added when ``seq_len < prefix_length``.
"""
row = prefix_att_masks[0]
ones = row.nonzero(as_tuple=False)
if ones.numel() == 0:
raise RuntimeError(
"SmolVLA2: state token not found in prefix att_masks — "
"can't locate language range."
)
state_start = int(ones[0, 0].item())
lang_end = state_start
lang_start = lang_end - num_lang
if lang_start < 0:
raise RuntimeError(
f"SmolVLA2: lang range underflows prefix "
f"(state_start={state_start}, num_lang={num_lang})."
)
return lang_start, lang_end
def _mark_target_span_causal(
prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int
) -> Tensor:
"""Make the supervised text-target span causally masked.
``embed_prefix`` flags every language token with ``att=0``, which
``make_att_2d_masks`` turns into one fully *bidirectional* block —
so a target token's hidden state attends to the very tokens it is
supposed to predict. The text cross-entropy then degenerates into
a copy task (loss → ~0) and the model never learns causal
next-token prediction — at inference, where ``select_message``
decodes autoregressively (causally), it collapses.
Fix: set ``att=1`` on the language positions that are supervised
targets (``text_labels != -100``). With ``make_att_2d_masks``'s
cumulative-block rule each target token then attends to images +
the user prompt bidirectionally and to *earlier* target tokens
only — i.e. genuine causal next-token prediction, matching
inference. Non-target language (the user prompt, and the
``low_level_execution`` subtask which is a user turn, not a
target) stays ``att=0`` / bidirectional. The action expert is
unaffected: the suffix has a strictly higher cumsum so it still
attends to every prefix token.
"""
att = prefix_att_masks.clone()
n = min(text_labels.shape[1], lang_end - lang_start)
if n <= 0:
return att
target = text_labels[:, :n] != -100 # (B, n) bool
seg = att[:, lang_start : lang_start + n].bool()
att[:, lang_start : lang_start + n] = seg | target
return att
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
num_lang = logits.shape[1]
if text_labels.shape[1] != num_lang:
common = min(text_labels.shape[1], num_lang)
logits = logits[:, :common]
text_labels = text_labels[:, :common]
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
valid = shift_labels != -100
loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
reduction="sum",
)
return loss / valid.sum().clamp(min=1)
class SmolVLA2Policy(SmolVLAPolicy):
"""SmolVLA + re-enabled SmolVLM language head."""
config_class = SmolVLA2Config
name = "smolvla2"
def __init__(self, config: SmolVLA2Config, **kwargs):
    """Build the policy, upgrading foreign configs and unfreezing the text head.

    A ``config`` that is not already a :class:`SmolVLA2Config` is
    converted by copying each of its dataclass fields into a new
    ``SmolVLA2Config``. After the parent constructor has run, the
    text-output path is unfrozen when ``unfreeze_lm_head`` is set and
    ``text_loss_weight`` is positive.
    """
    if not isinstance(config, SmolVLA2Config):
        field_names = [f.name for f in config.__dataclass_fields__.values()]
        carried_over = {
            field_name: getattr(config, field_name)
            for field_name in field_names
            if hasattr(config, field_name)
        }
        config = SmolVLA2Config(**carried_over)

    super().__init__(config, **kwargs)

    trains_text_head = config.unfreeze_lm_head and config.text_loss_weight > 0
    if trains_text_head:
        self._unfreeze_lm_head()
# ------------------------------------------------------------------
# Backbone surgery
# ------------------------------------------------------------------
def _unfreeze_lm_head(self) -> None:
"""Re-enable gradients on the text-output path so the LM head
loss can flow back.
SmolVLA's ``set_requires_grad`` freezes three things when
``train_expert_only=False``: ``lm_head``,
``text_model.model.norm.weight``, and the last 1-2 text-model
transformer layers (see ``smolvlm_with_expert.py:167-176``).
We must unfreeze *all three* — otherwise gradients still die
in the frozen final block and the lm_head learns nothing.
"""
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
if vlm_with_expert is None:
return
vlm = getattr(vlm_with_expert, "vlm", None)
if vlm is None:
return
# Mirror the freeze targets from ``smolvlm_with_expert.set_requires_grad``.
num_vlm = getattr(vlm_with_expert, "num_vlm_layers", None)
num_expert = getattr(vlm_with_expert, "num_expert_layers", None)
last_layers = []
if num_vlm is not None:
last_layers.append(num_vlm - 1)
if (
num_expert is not None
and num_vlm != num_expert
and num_vlm % num_expert == 0
):
last_layers.append(num_vlm - 2)
unfreeze_prefixes = [
"lm_head",
"text_model.model.norm.weight",
*[f"text_model.model.layers.{layer}." for layer in last_layers],
]
for name, param in vlm.named_parameters():
if any(k in name for k in unfreeze_prefixes):
param.requires_grad = True
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def forward(
    self,
    batch: dict[str, Tensor],
    noise: Tensor | None = None,
    time: Tensor | None = None,
    reduction: str = "mean",
) -> tuple[Tensor, dict[str, Any]]:
    """Compute the training loss, routing between the flow and text heads.

    Two routing knobs are read from ``batch`` (per the original author,
    produced by :class:`SmolVLA2ChatTokenizerStep`):

    * ``text_labels`` — per-token labels with ``-100`` for non-target
      positions. Enables the text-loss path through ``lm_head``, but
      only when it is a ``Tensor`` and ``text_loss_weight > 0``.
    * ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
      include this sample's action chunk in the flow loss.

    When neither knob is usable, the call is delegated unchanged to
    ``SmolVLAPolicy.forward``. Otherwise exactly one path runs:

    * **fused** — flow is runnable and text data is present: a single
      backbone pass via ``_compute_fused_loss`` yields both losses;
    * **flow only** — the parent forward is called with
      ``reduction="none"`` and masked by ``predict_actions``;
    * **text only** — ``_compute_text_loss``.

    Args:
        batch: Training batch. On the dual-head path ``OBS_STATE`` must
            be present (its device and leading dimension are used).
        noise: Optional flow noise, forwarded to the flow computation.
        time: Optional flow time, forwarded to the flow computation.
        reduction: ``"none"`` returns the scalar total broadcast to one
            entry per sample; any other value returns the scalar.

    Returns:
        ``(total, loss_dict)``. ``total`` is the weighted sum of the
        active losses. ``loss_dict`` holds Python floats under
        ``"loss"`` plus ``"flow_loss"`` / ``"text_loss"`` for the paths
        that ran, and the flow diagnostics re-keyed with a ``flow_``
        prefix.

    Raises:
        RuntimeError: when no loss path can run (e.g. both weights are
            zero, or the batch has neither action samples nor usable
            text labels).
    """
    text_labels = batch.get("text_labels")
    predict_actions_t = batch.get("predict_actions")

    # Text supervision counts only if labels are a tensor AND the text
    # loss is switched on in the config.
    has_text_data = (
        text_labels is not None
        and isinstance(text_labels, Tensor)
        and self.config.text_loss_weight > 0
    )
    has_per_sample_routing = (
        predict_actions_t is not None and isinstance(predict_actions_t, Tensor)
    )

    # Plain SmolVLA batch: behave exactly like the parent policy.
    if not has_text_data and not has_per_sample_routing:
        return super().forward(batch, noise=noise, time=time, reduction=reduction)

    loss_dict: dict[str, Any] = {}
    device = batch[OBS_STATE].device
    total = torch.zeros((), device=device, dtype=torch.float32)

    # Flow runs when it is weighted in, actions are present, and (if
    # routing is given) at least one sample asks for action prediction.
    # NOTE: ``.any().item()`` copies a value to the host.
    run_flow = (
        self.config.flow_loss_weight > 0
        and ACTION in batch
        and (not has_per_sample_routing or bool(predict_actions_t.any().item()))
    )

    # ------------------------------------------------------------
    # Fused path — one backbone forward for flow + text together.
    # ------------------------------------------------------------
    if run_flow and has_text_data:
        flow_loss, text_loss, flow_diag = self._compute_fused_loss(
            batch, text_labels, predict_actions_t, noise=noise, time=time
        )
        total = total + self.config.flow_loss_weight * flow_loss
        total = total + self.config.text_loss_weight * text_loss
        loss_dict["flow_loss"] = float(flow_loss.detach().item())
        loss_dict["text_loss"] = float(text_loss.detach().item())
        for k, v in flow_diag.items():
            loss_dict[f"flow_{k}"] = v
    elif run_flow:
        # Flow only: ask the parent for unreduced losses so samples
        # with ``predict_actions == False`` can be masked out.
        per_sample_flow, flow_diag = super().forward(
            batch, noise=noise, time=time, reduction="none"
        )
        if has_per_sample_routing:
            # assumes ``predict_actions`` lines up elementwise with the
            # parent's per-sample losses — TODO confirm shapes.
            mask = predict_actions_t.to(per_sample_flow.dtype)
            masked = per_sample_flow * mask
            # Average over selected samples only; clamp guards 0/0.
            denom = mask.sum().clamp(min=1.0)
            flow_loss = masked.sum() / denom
        else:
            flow_loss = per_sample_flow.mean()
        total = total + self.config.flow_loss_weight * flow_loss
        loss_dict["flow_loss"] = float(flow_loss.detach().item())
        for k, v in flow_diag.items():
            loss_dict[f"flow_{k}"] = v
    elif has_text_data:
        # Text only: prefix-only backbone pass, no action expert.
        text_loss = self._compute_text_loss(batch, text_labels)
        total = total + self.config.text_loss_weight * text_loss
        loss_dict["text_loss"] = float(text_loss.detach().item())
    else:
        # No path fired — happens when both loss weights are 0 or
        # the batch has neither action samples nor supervised text.
        # Fail loud rather than train silently on a zero loss.
        raise RuntimeError(
            "SmolVLA2Policy.forward: nothing to train — "
            "flow_loss_weight=%s, text_loss_weight=%s, "
            "predict_actions.any()=%s, has_text_data=%s"
            % (
                self.config.flow_loss_weight,
                self.config.text_loss_weight,
                bool(predict_actions_t.any().item()) if has_per_sample_routing else None,
                has_text_data,
            )
        )

    loss_dict["loss"] = float(total.detach().item())
    if reduction == "none":
        # Per-sample loss isn't meaningfully defined for the dual
        # path; broadcast the scalar to (B,) for caller compat.
        return total.expand(batch[OBS_STATE].shape[0]), loss_dict
    return total, loss_dict
# ------------------------------------------------------------------
# Text-loss internals
# ------------------------------------------------------------------
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
    """Return the LM-head cross-entropy over the supervised text tokens.

    Runs a prefix-only pass (images + language + state, no action
    suffix) through ``vlm_with_expert``, slices the hidden states that
    belong to the language tokens, projects them with the VLM's
    ``lm_head`` and scores them with ``_shifted_ce``.

    When ``adapt_to_pi_aloha`` is set, ``batch[OBS_STATE]`` is replaced
    in place with its decoded form.

    Raises:
        RuntimeError: if the backbone returns no prefix hidden states.
    """
    if self.config.adapt_to_pi_aloha:
        batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])

    backbone = self.model
    images, img_masks = self.prepare_images(batch)
    state = self.prepare_state(batch)
    tokens = batch[OBS_LANGUAGE_TOKENS]
    token_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]

    embs, pad_masks, att_masks = backbone.embed_prefix(
        images, img_masks, tokens, token_masks, state=state
    )

    # Locate the language span on the untouched att-mask first:
    # ``_locate_lang_range`` anchors on the first 1 (the state token),
    # and the causal marking below adds further 1s in front of it.
    lang_start, lang_end = _locate_lang_range(att_masks, tokens.shape[1])
    causal_att_masks = _mark_target_span_causal(
        att_masks, text_labels, lang_start, lang_end
    )

    # Prefix-only forward.
    outputs, _ = backbone.vlm_with_expert.forward(
        attention_mask=make_att_2d_masks(pad_masks, causal_att_masks),
        position_ids=torch.cumsum(pad_masks, dim=1) - 1,
        past_key_values=None,
        inputs_embeds=[embs, None],
        use_cache=False,
        fill_kv_cache=True,
    )
    if isinstance(outputs, (tuple, list)):
        prefix_hidden = outputs[0]
    else:
        prefix_hidden = outputs
    if prefix_hidden is None:
        raise RuntimeError(
            "SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
            "states — text-loss path needs them."
        )

    lm_head = backbone.vlm_with_expert.vlm.lm_head
    lang_hidden = prefix_hidden[:, lang_start:lang_end].to(lm_head.weight.dtype)
    return _shifted_ce(lm_head(lang_hidden), text_labels)  # logits: (B, num_lang, vocab)
# ------------------------------------------------------------------
# Fused flow + text loss (single backbone forward)
# ------------------------------------------------------------------
def _compute_fused_loss(
    self,
    batch: dict[str, Tensor],
    text_labels: Tensor,
    predict_actions_t: Tensor | None,
    noise: Tensor | None = None,
    time: Tensor | None = None,
) -> tuple[Tensor, Tensor, dict[str, Any]]:
    """One backbone forward → both flow MSE and text CE.

    Mirrors ``SmolVLAModel.forward`` (prefix + suffix concat, one
    ``vlm_with_expert`` call) but captures **both** outputs:

    * ``prefix_out[:, lang_start:lang_end]`` → ``lm_head`` → CE on
      ``text_labels`` (same slicing as ``_compute_text_loss``).
    * ``suffix_out[:, -chunk_size:]`` → ``action_out_proj`` → flow
      MSE against ``noise - actions`` (same as the parent forward).

    Saves one backbone pass per training step vs. running the flow
    and text paths separately — same trick PI052Policy uses in
    ``_compute_all_losses_fused``.

    Args:
        batch: Training batch. When ``adapt_to_pi_aloha`` is set,
            ``batch[OBS_STATE]`` and ``batch[ACTION]`` are overwritten
            in place with their converted values.
        text_labels: Per-token labels, ``-100`` on unsupervised tokens.
        predict_actions_t: Optional per-sample selector; when given,
            the flow loss is averaged over selected samples only.
        noise: Optional flow noise; sampled from the model if ``None``.
        time: Optional flow time; sampled from the model if ``None``.

    Returns:
        ``(flow_loss, text_loss, flow_diag)`` — two scalar tensors and
        a dict of float diagnostics taken at successive stages of the
        flow-loss computation.

    Raises:
        RuntimeError: if the backbone returns ``None`` for either the
            prefix or the suffix hidden states.
    """
    cfg = self.config
    if cfg.adapt_to_pi_aloha:
        batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
        batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])

    images, img_masks = self.prepare_images(batch)
    state = self.prepare_state(batch)
    lang_tokens = batch[OBS_LANGUAGE_TOKENS]
    lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
    actions = self.prepare_action(batch)

    inner = self.model
    # Noise is sampled before time; keep this order so runs with a
    # fixed RNG seed stay reproducible.
    if noise is None:
        noise = inner.sample_noise(actions.shape, actions.device)
    if time is None:
        time = inner.sample_time(actions.shape[0], actions.device)

    # Flow-matching interpolation: ``x_t`` equals ``actions`` at
    # time 0 and ``noise`` at time 1; ``u_t`` is the regression target.
    time_expanded = time[:, None, None]
    x_t = time_expanded * noise + (1 - time_expanded) * actions
    u_t = noise - actions

    prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
        images, img_masks, lang_tokens, lang_masks, state=state
    )
    # Causally mask the supervised text-target span (see
    # ``_mark_target_span_causal``). Per-sample: high_level_subtask
    # samples have a subtask target → causal; low_level_execution
    # samples have all -100 labels → untouched / bidirectional, so
    # the action expert still reads the subtask as bidirectional
    # context. ``lang_start`` / ``lang_end`` located here on the
    # unmodified mask and reused for the text-loss slice below.
    lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
    prefix_att_masks = _mark_target_span_causal(
        prefix_att_masks, text_labels, lang_start, lang_end
    )

    suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)

    # Concatenate prefix and suffix along the sequence axis so a single
    # backbone call covers both.
    pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
    att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
    att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
    # Position ids count non-padded tokens only.
    position_ids = torch.cumsum(pad_masks, dim=1) - 1

    out_pair, _ = inner.vlm_with_expert.forward(
        attention_mask=att_2d_masks,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=[prefix_embs, suffix_embs],
        use_cache=False,
        fill_kv_cache=False,
    )
    prefix_out, suffix_out = out_pair[0], out_pair[1]
    if prefix_out is None or suffix_out is None:
        raise RuntimeError(
            "SmolVLA2: fused forward expected both prefix and suffix "
            "hidden states from vlm_with_expert."
        )

    # ---------------- flow loss (per-sample maskable) ----------------
    chunk = cfg.chunk_size
    # Only the trailing ``chunk`` suffix positions are projected to
    # actions; the projection runs in float32.
    suffix_chunk = suffix_out[:, -chunk:].to(torch.float32)
    v_t = inner.action_out_proj(suffix_chunk)
    losses = F.mse_loss(u_t, v_t, reduction="none")
    # Keep only the dataset's real action dimensions.
    original_action_dim = cfg.action_feature.shape[0]
    losses = losses[:, :, :original_action_dim]

    flow_diag = {"losses_after_forward": float(losses.detach().mean().item())}
    actions_is_pad = batch.get("action_is_pad")
    if actions_is_pad is not None:
        # Zero the loss on padded timesteps.
        # assumes ``action_is_pad`` is a bool (B, chunk) tensor so the
        # unsqueeze broadcasts over the action dim — TODO confirm.
        in_episode = ~actions_is_pad
        losses = losses * in_episode.unsqueeze(-1)
        flow_diag["losses_after_in_ep_bound"] = float(losses.detach().mean().item())
    # NOTE(review): after the ``original_action_dim`` slice above this
    # only trims further if ``max_action_dim`` is smaller — presumably
    # kept to mirror the parent forward; confirm.
    losses = losses[:, :, : cfg.max_action_dim]
    flow_diag["losses_after_rm_padding"] = float(losses.detach().mean().item())

    # Reduce to one loss per sample. Zeroed (padded) entries still
    # count in the mean's denominator.
    per_sample_flow = losses.mean(dim=(1, 2))
    if predict_actions_t is not None:
        # Average over action-predicting samples only; clamp guards 0/0.
        mask = predict_actions_t.to(per_sample_flow.dtype)
        flow_loss = (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0)
    else:
        flow_loss = per_sample_flow.mean()

    # ---------------- text loss (lang slice of prefix) ---------------
    # ``lang_start`` / ``lang_end`` from above (unmodified mask).
    vlm = inner.vlm_with_expert.vlm
    lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
    logits = vlm.lm_head(lang_hidden)
    text_loss = _shifted_ce(logits, text_labels)

    return flow_loss, text_loss, flow_diag
# ------------------------------------------------------------------
# Inference: text generation
# ------------------------------------------------------------------
@torch.no_grad()
def select_message(
    self,
    batch: dict[str, Tensor],
    *,
    max_new_tokens: int = 256,
    min_new_tokens: int = 0,
    eos_token_id: int | None = None,
    temperature: float = 0.0,
    top_p: float = 1.0,
    tokenizer: Any = None,
) -> str:
    """Generate a text continuation from the chat-templated prompt.

    Autoregressive decoding over the SmolVLA prefix. No KV cache is
    carried between steps: each step re-forwards the cumulative
    ``[prefix + generated]`` embeddings (``use_cache=False``,
    ``past_key_values=None``). Batch size is assumed to be 1 (the
    runtime calls this per-event); only row 0 of the sampled ids is
    recorded and decoded. Returns the decoded string of new tokens
    (the prompt itself is not included).

    Side effects: switches the module to eval mode and stores a
    diagnostic string in ``self._last_select_message_debug`` (empty
    unless tokens were generated but decoded to an empty string).

    Parameters
    ----------
    batch:
        Already through the SmolVLA2 preprocessor — expects
        ``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``,
        ``OBS_LANGUAGE_ATTENTION_MASK``.
    max_new_tokens:
        Hard cap on generated tokens; stops earlier on EOS.
    min_new_tokens:
        Number of leading decode steps during which every special
        token (the tokenizer's ``all_special_ids`` plus EOS) is
        masked to ``-inf`` so the first tokens are renderable text.
        ``0`` disables the suppression.
    eos_token_id:
        Override the tokenizer's EOS. ``None`` ⇒ use the
        tokenizer's default.
    temperature, top_p:
        ``temperature=0`` does greedy argmax (default — matches
        training distribution most closely). Set ``temperature>0``
        with optional ``top_p<1`` for nucleus sampling.
    tokenizer:
        Optional pre-loaded tokenizer to avoid the cold-start
        ``AutoTokenizer.from_pretrained`` round-trip on every call.
    """
    self.eval()
    if tokenizer is None:
        from transformers import AutoTokenizer  # noqa: PLC0415

        tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name)
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    # Build the full set of special-token ids to suppress during
    # the ``min_new_tokens`` window. EOS alone is not enough on a
    # memorised SmolVLM head — when EOS is masked, the argmax
    # falls onto a sibling special token (``<|im_end|>``,
    # ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``,
    # …) which then survives generation but gets stripped by
    # ``skip_special_tokens=True`` so ``decode`` returns an empty
    # string and the runtime sees ``last_raw='(empty)'`` every
    # chunk boundary.
    special_ids_set: set[int] = set()
    try:
        for sid in (tokenizer.all_special_ids or []):
            if sid is not None:
                special_ids_set.add(int(sid))
    except Exception:  # noqa: BLE001
        # Deliberately best-effort: a tokenizer that cannot list its
        # special ids only loses the extra suppression, EOS is still
        # added below.
        pass
    if eos_token_id is not None:
        special_ids_set.add(int(eos_token_id))

    # Match training's text-loss forward path (see
    # ``_compute_text_loss`` above): build the full prefix via
    # ``embed_prefix`` so images + state conditioning is intact,
    # then loop AR with ``fill_kv_cache=True, use_cache=False``.
    # That flag combo routes every layer through
    # ``forward_attn_layer`` (which gracefully skips ``None``
    # expert inputs via ``if hidden_states is None or layer is
    # None: continue``) and short-circuits the cache-update logic
    # so we don't have to manage past_kv. Each step just
    # re-forwards the cumulative ``[prefix + generated]``
    # sequence.
    #
    # This is O(n²) in generated text length but cheap in
    # absolute terms: image encoding happens once via the initial
    # ``embed_prefix`` call, and the per-step cost is just one
    # SmolVLM transformer pass over a sequence that grows by one
    # token each time. Standard SmolVLM ``generate`` was the
    # other tempting path, but it can't accept SmolVLA's custom
    # ``state_proj`` output and its tile-grid expectations
    # disagree with our preprocessor — both lead to garbage
    # generation, which is what the prior approach produced.
    images, img_masks = self.prepare_images(batch)
    state = self.prepare_state(batch)
    lang_tokens = batch[OBS_LANGUAGE_TOKENS]
    lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
    prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
        images, img_masks, lang_tokens, lang_masks, state=state
    )

    # ``embed_prefix`` lays the prefix out as ``[images, lang, state]``
    # — the state token is LAST. Training supervises the text head on
    # the *language* positions (see ``_compute_text_loss`` /
    # ``_compute_fused_loss``: lm_head over ``prefix_out[lang_start:
    # lang_end]``). So AR text generation must continue from the last
    # language token (right after the ``Assistant:`` generation
    # prompt) — NOT from the state token, whose hidden state exists
    # for the action expert to read and which the lm_head was never
    # trained to decode subtask text from. Truncating the state token
    # here makes ``prefix_out[:, -1:]`` in the loop below the last
    # language position, matching the training distribution.
    _, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
    prefix_embs = prefix_embs[:, :lang_end]
    prefix_pad_masks = prefix_pad_masks[:, :lang_end]
    prefix_att_masks = prefix_att_masks[:, :lang_end]

    device = prefix_embs.device
    bsize = prefix_embs.shape[0]
    vlm = self.model.vlm_with_expert.vlm
    emb_dim = prefix_embs.shape[-1]
    text_emb_scale = math.sqrt(emb_dim)

    current_embs = prefix_embs
    current_pad = prefix_pad_masks
    current_att = prefix_att_masks
    ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
    generated: list[int] = []
    # Index tensor of special ids to mask; built once on first use,
    # when the logits' vocabulary size is known.
    suppress_idx: Tensor | None = None

    for _ in range(max_new_tokens):
        full_2d = make_att_2d_masks(current_pad, current_att)
        full_pos = torch.cumsum(current_pad, dim=1) - 1
        out_pair, _ = self.model.vlm_with_expert.forward(
            attention_mask=full_2d,
            position_ids=full_pos,
            past_key_values=None,
            inputs_embeds=[current_embs, None],
            use_cache=False,
            fill_kv_cache=True,
        )
        prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
        if prefix_out is None:
            raise RuntimeError(
                "select_message: vlm_with_expert.forward returned no hidden states."
            )
        last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
        logits_step = vlm.lm_head(last_hidden)[:, -1]  # (B, V)

        # Suppress *all* special tokens until we've decoded
        # ``min_new_tokens`` real (renderable) tokens. Without
        # this, a memorised SmolVLM head whose argmax at position
        # 0 is a special token produces an empty completion every
        # time — either EOS directly, or (after we mask EOS) the
        # argmax shifts to a sibling special id (``<|im_end|>``,
        # ``<image>``, ``<row_X_col_Y>``, …) which decode strips
        # via ``skip_special_tokens=True``. Masking the full
        # ``all_special_ids`` set for the first N steps forces
        # the head to commit to a normal vocabulary token before
        # it can close (or quietly poison) the turn.
        if special_ids_set and len(generated) < min_new_tokens:
            if suppress_idx is None:
                vocab_size = logits_step.shape[-1]
                # Ids outside the head's vocabulary cannot be sampled
                # anyway and would make the index write fail.
                in_vocab = sorted(sid for sid in special_ids_set if 0 <= sid < vocab_size)
                suppress_idx = torch.tensor(
                    in_vocab, dtype=torch.long, device=logits_step.device
                )
            if suppress_idx.numel() > 0:
                logits_step[..., suppress_idx] = float("-inf")

        next_ids = self._sample_next_token(logits_step, temperature, top_p)
        tok_id = int(next_ids[0].item())
        generated.append(tok_id)
        if eos_token_id is not None and tok_id == eos_token_id:
            break

        new_emb = self.model.vlm_with_expert.embed_language_tokens(
            next_ids.unsqueeze(0)
        )
        new_emb = new_emb * text_emb_scale
        current_embs = torch.cat([current_embs, new_emb], dim=1)
        current_pad = torch.cat([current_pad, ones_step], dim=1)
        current_att = torch.cat([current_att, ones_step], dim=1)

    decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()

    # When the visible decoded string is empty but tokens *were*
    # generated, expose what those raw tokens decoded to without
    # the special-token filter. This is what the runtime turns
    # into a scrollback line when ``last_raw='(empty)'`` so the
    # operator can tell whether the head is emitting EOS, image
    # placeholder tokens, the chat-template ``<|im_end|>`` shard,
    # or something else.
    if not decoded and generated:
        try:
            self._last_select_message_debug = (
                f"raw_ids={generated[:16]} "
                f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
            )
        except Exception:  # noqa: BLE001
            self._last_select_message_debug = f"raw_ids={generated[:16]}"
    else:
        self._last_select_message_debug = ""
    return decoded
@staticmethod
def _sample_next_token(
logits: Tensor, temperature: float, top_p: float
) -> Tensor:
"""Pick one token id per batch row from ``logits``."""
if temperature <= 0.0:
return logits.argmax(dim=-1)
scaled = logits / max(temperature, 1e-6)
probs = F.softmax(scaled, dim=-1)
if top_p < 1.0:
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
cum = sorted_probs.cumsum(dim=-1)
mask = cum > top_p
# Always keep the most-likely token.
mask[..., 0] = False
sorted_probs = sorted_probs.masked_fill(mask, 0.0)
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
pick = torch.multinomial(sorted_probs, num_samples=1)
return sorted_idx.gather(-1, pick).squeeze(-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
@@ -1,134 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 processor pipelines.
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
rename observations
add batch dim
RenderMessagesStep(recipe) # PR 1: language_* → messages
SmolVLA2ChatTokenizerStep(...) # chat template + label mask + predict_actions
DeviceProcessorStep
NormalizerProcessorStep
When ``config.recipe_path`` is ``None``, we delegate to SmolVLA's
plain task-string pipeline so unannotated datasets still work.
Post-processor is unchanged from SmolVLA.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
RenderMessagesStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
from .chat_processor_smolvla2 import SmolVLA2ChatTokenizerStep
from .configuration_smolvla2 import SmolVLA2Config
def make_smolvla2_pre_post_processors(
    config: SmolVLA2Config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Assemble the ``(preprocessor, postprocessor)`` pair for SmolVLA2.

    When ``config.recipe_path`` is falsy the call is forwarded unchanged
    to ``make_smolvla_pre_post_processors``. Otherwise the recipe is
    loaded and the pre-processor runs, in order: observation rename,
    batch-dim insertion, recipe message rendering, chat tokenization,
    device placement and normalization. The post-processor unnormalizes
    the action and moves it to the CPU.
    """
    if not config.recipe_path:
        return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)

    loaded_recipe = _load_recipe(config.recipe_path)

    # Steps are constructed in the same order they run.
    pre_steps = []
    pre_steps.append(RenameObservationsProcessorStep(rename_map={}))
    pre_steps.append(AddBatchDimensionProcessorStep())
    pre_steps.append(RenderMessagesStep(recipe=loaded_recipe))
    pre_steps.append(
        SmolVLA2ChatTokenizerStep(
            tokenizer_name=config.vlm_model_name,
            max_length=config.tokenizer_max_length,
            padding=config.pad_language_to,
            # Dropout knobs are optional on the config; absent ones mean 0.0.
            **{
                knob: getattr(config, knob, 0.0)
                for knob in (
                    "plan_dropout_prob",
                    "memory_dropout_prob",
                    "subtask_dropout_prob",
                )
            },
        )
    )
    pre_steps.append(DeviceProcessorStep(device=config.device))
    pre_steps.append(
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        )
    )

    post_steps = []
    post_steps.append(
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        )
    )
    post_steps.append(DeviceProcessorStep(device="cpu"))

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=pre_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=post_steps,
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return preprocessor, postprocessor
def _load_recipe(path_str: str) -> TrainingRecipe:
    """Load the ``TrainingRecipe`` YAML that ``path_str`` points at.

    An absolute path, or a relative path that already exists, is used
    as given. Any other relative path is looked up inside the directory
    that holds the ``lerobot.configs.recipe`` module, so recipe authors
    can write ``--policy.recipe_path=recipes/subtasks_vqa.yaml``. If
    that lookup finds nothing, the path as given is passed to
    ``TrainingRecipe.from_yaml`` unchanged.
    """
    given = Path(path_str)
    if given.is_absolute() or given.exists():
        return TrainingRecipe.from_yaml(given)

    from lerobot.configs import recipe as _recipe_module  # noqa: PLC0415

    bundled = Path(_recipe_module.__file__).resolve().parent / path_str
    return TrainingRecipe.from_yaml(bundled if bundled.exists() else given)
@@ -12,10 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2.
"""``lerobot-pi052-runtime`` — interactive REPL for trained PI052.
Drives the multi-rate runtime defined in
:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user
:mod:`lerobot.policies.pi052.inference`. Stdin becomes the user
channel: type a task, then natural-language interjections / questions.
The runtime prints state changes (plan / subtask / memory / vqa /
speech) as they happen.
@@ -26,16 +26,16 @@ Examples
Dry run on a Hub checkpoint, no robot connected — useful for sanity-
checking text generation::
uv run lerobot-smolvla2-runtime \\
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
uv run lerobot-pi052-runtime \\
--policy.path=pepijn223/pi052_hirobot_super_poulain_tool2 \\
--no_robot \\
--task="please clean the kitchen"
Same, but feed real frames from an annotated dataset so plan / subtask
/ memory / VQA generation runs against actual video + state::
uv run lerobot-smolvla2-runtime \\
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
uv run lerobot-pi052-runtime \\
--policy.path=pepijn223/pi052_hirobot_super_poulain_tool2 \\
--dataset.repo_id=pepijn223/super_poulain_annotated \\
--dataset.episode=0 \\
--no_robot \\
@@ -43,7 +43,7 @@ Same, but feed real frames from an annotated dataset so plan / subtask
With a real robot::
uv run lerobot-smolvla2-runtime \\
uv run lerobot-pi052-runtime \\
--policy.path=... \\
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
--tts.voice=alba
@@ -62,18 +62,14 @@ import logging
import sys
from typing import Any, Callable
logger = logging.getLogger("lerobot.smolvla2.runtime")
logger = logging.getLogger("lerobot.pi052.runtime")
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p = argparse.ArgumentParser(
# prog defaults to the invoked command name, so this reads
# correctly whether run as lerobot-smolvla2-runtime or
# lerobot-pi052-runtime.
description=(
"Interactive REPL runtime for a trained hierarchical VLA "
"checkpoint (SmolVLA2 or PI052 — policy type is read from "
"the checkpoint)."
"Interactive REPL runtime for a trained PI052 hierarchical "
"VLA checkpoint."
),
)
p.add_argument(
@@ -82,7 +78,7 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
type=str,
required=True,
help=(
"Local directory or Hugging Face Hub repo id pointing at a trained SmolVLA2 ``pretrained_model``."
"Local directory or Hugging Face Hub repo id pointing at a trained PI052 ``pretrained_model``."
),
)
p.add_argument(
@@ -368,7 +364,7 @@ def _load_policy_and_preprocessor(
policy_path: str,
dataset_repo_id: str | None,
) -> tuple[Any, Any, Any, Any]:
"""Load a SmolVLA2 checkpoint (local path or Hub repo id).
"""Load a PI052 checkpoint (local path or Hub repo id).
Returns ``(policy, preprocessor, postprocessor, ds_meta)``.
``preprocessor`` / ``postprocessor`` / ``ds_meta`` are ``None``
@@ -430,7 +426,7 @@ def _build_observation_provider(
The dataset's ``language_persistent`` / ``language_events``
columns are stripped before the sample reaches the preprocessor,
so ``RenderMessagesStep`` and ``SmolVLA2ChatTokenizerStep`` are
so ``RenderMessagesStep`` and ``PI052TextTokenizerStep`` are
no-ops; the runtime supplies its own messages from current state.
"""
import torch # noqa: PLC0415
@@ -613,7 +609,7 @@ def _select_task_interactively(
# bootstrap default (may be None — REPL handles that).
return bootstrap_task
print("\n[smolvla2] Select startup task:", flush=True)
print("\n[pi052] Select startup task:", flush=True)
if options:
for i, opt in enumerate(options, 1):
marker = " (dataset default)" if opt == bootstrap_task else ""
@@ -816,7 +812,7 @@ def _build_robot_observation_provider(
# ``ds_features``. The training distribution sees frames at the
# recorded resolution (e.g. 480×640); a live Mac/USB camera will
# almost always hand us a different native size (720p / 1080p).
# SmolVLA's internal ``resize_with_pad(512, 512)`` does pad the
# PI052's internal ``resize_with_pad(512, 512)`` does pad the
# input to a fixed canvas, but the *geometry* of that pad differs
# by input aspect ratio — top/left padding varies, so the visual
# tokens at each tile carry different content than what the model
@@ -1017,7 +1013,7 @@ def _build_robot_action_executor(
def _print_runtime_help() -> None:
"""Print the slash-command reference."""
print(
"[smolvla2] commands (arguments need no quotes):\n"
"[pi052] commands (arguments need no quotes):\n"
" /action <task> run the robot; an argument switches to that task\n"
" /action resume the robot on the current task\n"
" /action <seconds> run the robot for N seconds, then auto-pause\n"
@@ -1085,7 +1081,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
secs = float(rest)
runtime.state["action_deadline"] = _time.monotonic() + secs
print(
f"[smolvla2] action — running {secs:g}s, then auto-pause",
f"[pi052] action — running {secs:g}s, then auto-pause",
flush=True,
)
else:
@@ -1095,16 +1091,16 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
# New task → drop the stale subtask so the high-level
# loop regenerates one for the new goal.
runtime.state["current_subtask"] = None
print(f"[smolvla2] action — task: {rest!r}", flush=True)
print(f"[pi052] action — task: {rest!r}", flush=True)
elif runtime.state.get("task"):
print(
f"[smolvla2] action — resuming: {runtime.state['task']!r}",
f"[pi052] action — resuming: {runtime.state['task']!r}",
flush=True,
)
else:
runtime.state["mode"] = "paused"
print(
"[smolvla2] no task set — use /action <your task>",
"[pi052] no task set — use /action <your task>",
flush=True,
)
return True
@@ -1113,7 +1109,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
runtime.state["mode"] = "paused"
runtime.state["action_deadline"] = None
_clear_action_queue(runtime)
print("[smolvla2] paused — robot holding position", flush=True)
print("[pi052] paused — robot holding position", flush=True)
return True
if cmd in {"/question", "/q", "/ask", "/vqa", "/vlm"}:
@@ -1124,7 +1120,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
_clear_action_queue(runtime)
if not rest:
print(
"[smolvla2] usage: /question <your question> "
"[pi052] usage: /question <your question> "
"(e.g. /question point to the yellow cube)",
flush=True,
)
@@ -1144,7 +1140,7 @@ def _run_vqa_query(runtime: Any, question: str) -> None:
Invoked by ``/question`` — the action loop is paused first so the
policy is free for a synchronous VQA call.
"""
from lerobot.policies.smolvla2.inference.vqa import handle_vqa_query # noqa: PLC0415
from lerobot.policies.pi052.inference.vqa import handle_vqa_query # noqa: PLC0415
handle_vqa_query(
policy=runtime.policy,
@@ -1180,11 +1176,11 @@ def _run_autonomous(
if not auto_start and runtime.state.get("mode", "paused") == "action":
try:
input(
"[smolvla2] Robot connected — starting in ACTION mode. "
"[pi052] Robot connected — starting in ACTION mode. "
"Press ENTER to begin, Ctrl+C to abort. "
)
except (EOFError, KeyboardInterrupt):
print("\n[smolvla2] aborted before start", flush=True)
print("\n[pi052] aborted before start", flush=True)
return 130
if initial_task:
@@ -1193,7 +1189,7 @@ def _run_autonomous(
thread = threading.Thread(
target=runtime.run,
kwargs={"max_ticks": max_ticks},
name="smolvla2-runtime-loop",
name="pi052-runtime-loop",
daemon=True,
)
thread.start()
@@ -1251,7 +1247,7 @@ def _run_autonomous(
if hasattr(queue, "clear"):
queue.clear()
print(
"\n[smolvla2] timed action elapsed — paused",
"\n[pi052] timed action elapsed — paused",
flush=True,
)
else:
@@ -1264,7 +1260,7 @@ def _run_autonomous(
pass
_panel_stop.wait(0.7)
panel_thread = threading.Thread(target=_panel_loop, name="smolvla2-panel-redraw", daemon=True)
panel_thread = threading.Thread(target=_panel_loop, name="pi052-panel-redraw", daemon=True)
panel_thread.start()
try:
@@ -1296,11 +1292,11 @@ def _run_autonomous(
runtime.state.setdefault("events_this_tick", []).append("user_interjection")
else:
print(
"[smolvla2] no task yet — use /action <your task> to start",
"[pi052] no task yet — use /action <your task> to start",
flush=True,
)
except KeyboardInterrupt:
print("\n[smolvla2] interrupt — stopping", flush=True)
print("\n[pi052] interrupt — stopping", flush=True)
finally:
_panel_stop.set()
runtime.stop()
@@ -1311,9 +1307,9 @@ def _run_autonomous(
time.sleep(0.1)
try:
robot.disconnect()
print("[smolvla2] robot disconnected", flush=True)
print("[pi052] robot disconnected", flush=True)
except Exception as exc: # noqa: BLE001
print(f"[smolvla2] WARNING: robot.disconnect raised {exc}", flush=True)
print(f"[pi052] WARNING: robot.disconnect raised {exc}", flush=True)
return 0
@@ -1340,7 +1336,7 @@ def _make_state_panel_renderer(
st = runtime.state
run_mode = st.get("mode", "action")
mode_tag = "[green]mode: action[/]" if run_mode == "action" else "[yellow]mode: paused[/]"
console.rule(f"[bold]SmolVLA2[/] · {mode_label} · {mode_tag}", style="cyan")
console.rule(f"[bold]PI052[/] · {mode_label} · {mode_tag}", style="cyan")
# Always-visible command hint so the operator never has to
# remember the slash commands.
if run_mode == "action":
@@ -1499,14 +1495,14 @@ def main(argv: list[str] | None = None) -> int:
autonomous_mode = bool(args.robot_type) and not args.no_robot
if autonomous_mode and not args.dataset_repo_id:
print(
"[smolvla2] ERROR: autonomous robot mode requires --dataset.repo_id "
"[pi052] ERROR: autonomous robot mode requires --dataset.repo_id "
"for action-denormalisation stats and feature shapes. Pass the "
"same dataset the policy was trained on.",
file=sys.stderr,
)
return 2
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
print(f"[pi052] loading policy from {args.policy_path}", flush=True)
policy, preprocessor, postprocessor, ds_meta = _load_policy_and_preprocessor(
args.policy_path, args.dataset_repo_id
)
@@ -1537,7 +1533,7 @@ def main(argv: list[str] | None = None) -> int:
)
if chosen:
args.task = chosen
print(f"[smolvla2] task: {args.task!r}", flush=True)
print(f"[pi052] task: {args.task!r}", flush=True)
# No startup prompts — the runtime is command-driven. It comes up at
# the command line in ``paused`` mode (robot idle) unless ``--mode``
@@ -1551,7 +1547,7 @@ def main(argv: list[str] | None = None) -> int:
if autonomous_mode:
print(
f"[smolvla2] connecting to robot.type={args.robot_type} port={args.robot_port}",
f"[pi052] connecting to robot.type={args.robot_type} port={args.robot_port}",
flush=True,
)
robot = _build_robot(
@@ -1575,7 +1571,7 @@ def main(argv: list[str] | None = None) -> int:
)
elif args.dataset_repo_id is not None:
print(
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
f"[pi052] streaming observations from {args.dataset_repo_id} "
f"episode={args.dataset_episode} "
f"start_frame={args.dataset_start_frame}",
flush=True,
@@ -1592,11 +1588,11 @@ def main(argv: list[str] | None = None) -> int:
tools = _build_tools(args.no_tts, args.tts_voice)
if tools:
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
print(f"[pi052] tools loaded: {list(tools)}", flush=True)
from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415
from lerobot.policies.pi052.inference import PI052Runtime # noqa: PLC0415
runtime = SmolVLA2Runtime(
runtime = PI052Runtime(
policy=policy,
tools=tools,
observation_provider=observation_provider,
@@ -1662,7 +1658,7 @@ def main(argv: list[str] | None = None) -> int:
logger.warning("startup tick failed: %s", exc)
startup_logs = []
for line in startup_logs or []:
print(f"[smolvla2] {line}", flush=True)
print(f"[pi052] {line}", flush=True)
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
@@ -1680,7 +1676,7 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
from rich.console import Console # noqa: PLC0415
except ImportError:
print(
"[smolvla2] rich is required for the interactive REPL. `pip install rich` and re-run.",
"[pi052] rich is required for the interactive REPL. `pip install rich` and re-run.",
file=sys.stderr,
)
return 2
@@ -1724,7 +1720,7 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
# task to be meaningful.
if not runtime.state.get("task"):
print(
"[smolvla2] no task yet — use /action <your task>",
"[pi052] no task yet — use /action <your task>",
flush=True,
)
_redraw(last_logs)
+3 -4
View File
@@ -13,10 +13,9 @@
# limitations under the License.
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
The first concrete tool implementation. SmolVLA2 (PR 3) and downstream
runtime dispatchers consume this when the model emits an assistant
message with ``tool_calls=[{function: {name: "say", arguments:
{text: ...}}}]``.
The first concrete tool implementation. PI052 and downstream runtime
dispatchers consume this when the model emits an assistant message
with ``tool_calls=[{function: {name: "say", arguments: {text: ...}}}]``.
Why pocket-tts:
+2 -2
View File
@@ -162,7 +162,7 @@ def test_messages_vqa_to_loc_noop_without_target_indices():
def test_loc_round_trip_keypoint_preserves_normalized_coords():
from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
loc = _vqa_answer_to_loc(answer)
@@ -175,7 +175,7 @@ def test_loc_round_trip_keypoint_preserves_normalized_coords():
def test_loc_round_trip_bbox_preserves_order_and_scale():
from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
answer = {
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}]
@@ -1,163 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention-masking tests for the SmolVLA2 text head.
Regression coverage for the text-CE collapse bug: ``embed_prefix`` flags
every language token ``att=0``, which ``make_att_2d_masks`` turns into a
single fully *bidirectional* block. Under that mask the text
cross-entropy degenerates into a copy task — a supervised target token
attends to the tokens it is trained to predict and the model never
learns causal generation, so ``select_message`` collapses at inference.
``_mark_target_span_causal`` sets ``att=1`` on the supervised target
language positions so each target token attends causally among the
targets while staying bidirectional to images + the user prompt. These
tests pin that behaviour.
"""
import pytest
import torch
# The smolvla2 modeling module imports transformers transitively.
pytest.importorskip("transformers")
from lerobot.policies.smolvla.modeling_smolvla import make_att_2d_masks # noqa: E402
from lerobot.policies.smolvla2.modeling_smolvla2 import ( # noqa: E402
_locate_lang_range,
_mark_target_span_causal,
)
# ---------------------------------------------------------------------------
# A synthetic SmolVLA prefix layout: [images, prompt-lang, target-lang, state]
#
# indices 0-1 : 2 image tokens (att = 0)
# indices 2-4 : 3 user-prompt lang (att = 0)
# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix)
# index 9 : 1 state token (att = 1)
#
# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
# real ids on the 4-token target span.
# ---------------------------------------------------------------------------
# Span sizes of the synthetic prefix described in the layout comment above.
N_IMAGE = 2
N_PROMPT = 3
N_TARGET = 4
# The language span starts right after the image tokens and covers
# prompt + target tokens.
LANG_START = N_IMAGE
LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = state-token index
# One extra position for the single trailing state token.
PREFIX_LEN = LANG_END + 1
def _embed_prefix_att_masks() -> torch.Tensor:
    """Return a ``(1, PREFIX_LEN)`` bool mask shaped like ``embed_prefix``'s.

    Every image and language position is ``False`` (att=0); only the
    single state token at index ``LANG_END`` is ``True`` (att=1).
    """
    positions = torch.arange(PREFIX_LEN)
    return positions.eq(LANG_END).unsqueeze(0)
def _text_labels() -> torch.Tensor:
    """Return ``(1, N_PROMPT + N_TARGET)`` labels for the language span.

    The prompt positions carry the ignore value ``-100``; the target
    positions carry the consecutive ids ``10, 11, ...``.
    """
    prompt_part = torch.full((N_PROMPT,), -100, dtype=torch.long)
    target_part = torch.arange(10, 10 + N_TARGET)
    return torch.cat([prompt_part, target_part]).unsqueeze(0)
def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
    """Expand ``prefix_att_masks`` into the 2D attendance matrix of row 0.

    Entry ``[i, j]`` is True when position ``i`` attends to position
    ``j``. No position is treated as padding.
    """
    no_padding = torch.ones(1, PREFIX_LEN, dtype=torch.bool)
    att_2d = make_att_2d_masks(no_padding, prefix_att_masks)
    return att_2d[0]
def test_locate_lang_range_anchors_on_state_token():
    """The language span is recovered from the single att=1 state token."""
    att_masks = _embed_prefix_att_masks()
    num_lang_tokens = N_PROMPT + N_TARGET
    found_start, found_end = _locate_lang_range(att_masks, num_lang=num_lang_tokens)
    assert (found_start, found_end) == (LANG_START, LANG_END)
def test_mark_sets_att_on_targets_only():
    """Marking flips att to 1 on the supervised target span and nowhere else."""
    marked = _mark_target_span_causal(
        _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
    )
    target_span = range(LANG_START + N_PROMPT, LANG_END)
    # True on the target span and on the (untouched) state token only.
    expected = [pos in target_span or pos == LANG_END for pos in range(PREFIX_LEN)]
    assert marked[0].tolist() == expected
def test_target_tokens_attend_causally_among_themselves():
    """Within the target span attention is causal: each target sees
    itself and earlier targets, never a later one."""
    causal_att = _mark_target_span_causal(
        _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
    )
    attends = _attends(causal_att)
    first_target = LANG_START + N_PROMPT
    for i in range(first_target, LANG_END):
        for j in range(first_target, LANG_END):
            if j <= i:
                assert attends[i, j], f"target {i} must see earlier/self target {j}"
                continue
            assert not attends[i, j], f"target {i} must not see future target {j}"
def test_target_tokens_attend_prompt_and_images_bidirectionally():
    """Every target token still sees all image and user-prompt positions."""
    causal_att = _mark_target_span_causal(
        _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
    )
    attends = _attends(causal_att)
    first_target = LANG_START + N_PROMPT  # everything before it is context
    for i in range(first_target, LANG_END):
        for j in range(first_target):
            assert attends[i, j], f"target {i} must attend context {j}"
def test_action_expert_token_still_sees_full_subtask():
    """Causal masking of the targets must not hide them from the state
    token, which is the action expert's context and attends to every
    target."""
    causal_att = _mark_target_span_causal(
        _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
    )
    state_row = _attends(causal_att)[LANG_END]
    first_target = LANG_START + N_PROMPT
    for j in range(first_target, LANG_END):
        assert state_row[j], f"state token must see target {j}"
def test_non_target_subtask_stays_bidirectional():
    """With no supervised positions the mask comes back unchanged.

    ``low_level_execution`` renders the subtask as a user turn, so its
    ``text_labels`` are all -100 and the action expert keeps reading the
    subtask bidirectionally.
    """
    lang_len = N_PROMPT + N_TARGET
    ignore_everywhere = torch.full((1, lang_len), -100, dtype=torch.long)
    result = _mark_target_span_causal(
        _embed_prefix_att_masks(), ignore_everywhere, LANG_START, LANG_END
    )
    untouched = _embed_prefix_att_masks()
    assert torch.equal(result, untouched)
def test_unmarked_mask_is_bidirectional_the_bug():
    """Pin the bug the fix prevents.

    Without ``_mark_target_span_causal`` the raw mask lets a target
    token attend to later targets, so the text-CE can copy the answer
    it is trained to predict.
    """
    raw_attends = _attends(_embed_prefix_att_masks())
    earliest_target = LANG_START + N_PROMPT
    latest_target = LANG_END - 1
    sees_future = raw_attends[earliest_target, latest_target]
    assert sees_future, (
        "raw embed_prefix mask is bidirectional over language — the first "
        "target token can see the last, which is the collapse bug"
    )
@@ -1,77 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SmolVLA2's chat-tokenizer ``tool_calls`` flattening.
``_split_plan_and_say`` (inference) expects the model to emit a textual
``<say>...</say>`` marker. ``flatten_say_tool_calls`` is the training-time
serializer that produces it: it rewrites an assistant turn's structured
``say`` tool call into that marker *inside the content text*, before
``apply_chat_template`` runs — so the chat template only tokenizes plain
text and the supervised target span trains the model to emit the marker
the runtime parses back. These tests pin the round-trip.
"""
from lerobot.policies.smolvla2.chat_processor_smolvla2 import flatten_say_tool_calls
from lerobot.policies.smolvla2.inference.steps import _split_plan_and_say
def _say_call(text):
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
def test_flatten_appends_say_marker_and_drops_tool_calls():
    """A ``say`` call becomes a trailing ``<say>`` marker and ``tool_calls`` is removed."""
    message = {
        "role": "assistant",
        "content": "Pick up the blue cube.",
        "tool_calls": [_say_call("On it!")],
    }
    flattened = flatten_say_tool_calls(message)
    assert "tool_calls" not in flattened
    assert flattened["content"] == "Pick up the blue cube.\n<say>On it!</say>"
def test_flatten_roundtrips_through_inference_parser():
    """Check the train/inference contract between serializer and parser.

    The content written by ``flatten_say_tool_calls`` must split back,
    via ``_split_plan_and_say``, into the original plan text and the
    spoken text.
    """
    message = {
        "role": "assistant",
        "content": "Move toward the cube.",
        "tool_calls": [_say_call("Working on it")],
    }
    flattened_content = flatten_say_tool_calls(message)["content"]
    parsed_plan, parsed_speech = _split_plan_and_say(flattened_content)
    assert parsed_plan == "Move toward the cube."
    assert parsed_speech == "Working on it"
def test_flatten_accepts_json_string_arguments():
    """A ``say`` call whose ``arguments`` is a JSON string (not a dict) is still flattened."""
    json_args_call = {
        "type": "function",
        "function": {"name": "say", "arguments": '{"text": "hello there"}'},
    }
    message = {"role": "assistant", "content": "p", "tool_calls": [json_args_call]}
    flattened = flatten_say_tool_calls(message)
    assert flattened["content"] == "p\n<say>hello there</say>"
def test_flatten_leaves_messages_without_tool_calls_untouched():
    """A message with no ``tool_calls`` key compares equal to the input after flattening."""
    plain_message = {"role": "assistant", "content": "just a plan"}
    result = flatten_say_tool_calls(plain_message)
    assert result == plain_message
def test_flatten_drops_empty_or_non_say_tool_calls():
    """A non-``say`` tool call leaves the content as is, yet ``tool_calls`` is removed.

    Stripping the structured calls means the chat template renders no JSON
    block. Only the non-``say`` case is exercised here; the empty-text case
    named in the test title is not covered by this body.
    """
    weather_call = {
        "type": "function",
        "function": {"name": "check_weather", "arguments": {}},
    }
    message = {"role": "assistant", "content": "plan only", "tool_calls": [weather_call]}
    flattened = flatten_say_tool_calls(message)
    assert flattened["content"] == "plan only"
    assert "tool_calls" not in flattened
def test_flatten_marker_only_when_content_empty():
    """With empty content the result is the bare ``<say>`` marker, no leading newline."""
    message = {
        "role": "assistant",
        "content": "",
        "tool_calls": [_say_call("hi")],
    }
    flattened = flatten_say_tool_calls(message)
    assert flattened["content"] == "<say>hi</say>"
@@ -1,228 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
Covers camera selection, VQA-answer parsing, and the bounding-box /
keypoint overlay drawing — the pure functions, no model load.
"""
import numpy as np
import pytest
from lerobot.policies.smolvla2.inference.vqa import (
answer_has_overlay,
available_cameras,
camera_short_name,
draw_vqa_overlay,
observation_image_to_pil,
parse_vqa_answer,
prompt_camera_choice,
)
PIL = pytest.importorskip("PIL")
from PIL import Image # noqa: E402
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def test_available_cameras_extracts_and_sorts_image_keys():
    """The two image keys come back sorted; the state and task keys are left out."""
    observation = {
        "observation.images.wrist": object(),
        "observation.state": object(),
        "observation.images.top": object(),
        "task": "x",
    }
    expected_cameras = ["observation.images.top", "observation.images.wrist"]
    assert available_cameras(observation) == expected_cameras
def test_available_cameras_handles_none_and_empty():
    """Both ``None`` and an empty dict yield an empty camera list."""
    for empty_observation in (None, {}):
        assert available_cameras(empty_observation) == []
def test_camera_short_name_strips_prefix():
    """A fully qualified key is shortened to ``top``; an already-short name is unchanged."""
    for camera_key in ("observation.images.top", "top"):
        assert camera_short_name(camera_key) == "top"
def test_prompt_camera_choice_single_camera_auto_selects():
    """With one camera available it is returned without prompting.

    ``_boom`` is passed as ``input_fn`` so any attempt to ask the user
    fails the test.
    """
    only_camera = "observation.images.top"
    selected = prompt_camera_choice(
        [only_camera],
        input_fn=_boom,
        print_fn=lambda *_: None,
    )
    assert selected == "observation.images.top"
def test_prompt_camera_choice_by_number():
    """Answering ``"2"`` selects the second listed camera."""
    camera_keys = ["observation.images.top", "observation.images.wrist"]
    selected = prompt_camera_choice(
        camera_keys,
        input_fn=lambda _: "2",
        print_fn=lambda *_: None,
    )
    assert selected == "observation.images.wrist"
def test_prompt_camera_choice_by_name():
    """Answering with the short name ``"top"`` selects the matching camera key."""
    camera_keys = ["observation.images.top", "observation.images.wrist"]
    selected = prompt_camera_choice(
        camera_keys,
        input_fn=lambda _: "top",
        print_fn=lambda *_: None,
    )
    assert selected == "observation.images.top"
def test_prompt_camera_choice_invalid_returns_none():
    """An out-of-range answer (``"99"`` with two cameras) yields ``None``."""
    camera_keys = ["observation.images.top", "observation.images.wrist"]
    selected = prompt_camera_choice(
        camera_keys,
        input_fn=lambda _: "99",
        print_fn=lambda *_: None,
    )
    assert selected is None
def _boom(*_args, **_kwargs):
raise AssertionError("input_fn should not be called")
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def test_parse_bbox_answer():
    """A JSON ``detections`` payload parses as kind ``bbox`` and is drawable."""
    raw_answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    result = parse_vqa_answer(raw_answer)
    assert result["kind"] == "bbox"
    assert answer_has_overlay(result)
def test_parse_keypoint_answer():
    """A JSON ``point`` payload parses as kind ``keypoint`` and is drawable."""
    raw_answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
    result = parse_vqa_answer(raw_answer)
    assert result["kind"] == "keypoint"
    assert answer_has_overlay(result)
def test_parse_count_answer_is_not_an_overlay():
    """A ``count`` payload parses as kind ``count`` and has nothing to draw."""
    raw_answer = '{"label": "cubes", "count": 2}'
    result = parse_vqa_answer(raw_answer)
    assert result["kind"] == "count"
    assert not answer_has_overlay(result)
def test_parse_invalid_json_returns_none():
    """Non-JSON text, the empty string and a JSON array all parse to ``None``."""
    non_answers = (
        "not json at all",
        "",
        # A JSON array is valid JSON but not a VQA answer object.
        "[1, 2, 3]",
    )
    for raw_answer in non_answers:
        assert parse_vqa_answer(raw_answer) is None
def test_parse_unknown_shape():
    """A JSON object with unrecognised keys parses as kind ``unknown`` with no overlay."""
    raw_answer = '{"weird": "payload"}'
    result = parse_vqa_answer(raw_answer)
    assert result["kind"] == "unknown"
    assert not answer_has_overlay(result)
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def _blank(size=(160, 120)):
    """Return a new all-black RGB image of ``size`` (default ``(160, 120)``)."""
    black = (0, 0, 0)
    return Image.new("RGB", size, black)
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
    """Drawing a parsed bbox answer keeps the image size but alters its pixels."""
    canvas = _blank()
    bbox_json = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
    annotated = draw_vqa_overlay(canvas, parse_vqa_answer(bbox_json))
    assert annotated.size == canvas.size
    assert annotated.tobytes() != canvas.tobytes()
def test_draw_keypoint_overlay_changes_pixels():
    """Drawing a parsed keypoint answer keeps the image size but alters its pixels."""
    canvas = _blank()
    keypoint_json = '{"label": "cube", "point_format": "xy", "point": [80, 60]}'
    annotated = draw_vqa_overlay(canvas, parse_vqa_answer(keypoint_json))
    assert annotated.size == canvas.size
    assert annotated.tobytes() != canvas.tobytes()
def test_draw_overlay_non_spatial_leaves_image_unchanged():
    """A ``count`` answer has no spatial content, so the output bytes equal the input's."""
    canvas = _blank()
    count_json = '{"label": "cubes", "count": 2}'
    result = draw_vqa_overlay(canvas, parse_vqa_answer(count_json))
    assert result.tobytes() == canvas.tobytes()
def test_draw_overlay_tolerates_malformed_coordinates():
    """A detection whose ``bbox`` has too few values is handled without raising."""
    canvas = _blank()
    # bbox with the wrong arity (two values) must not raise.
    short_bbox = {"bbox": [1, 2]}
    malformed = {"kind": "bbox", "payload": {"detections": [short_bbox]}}
    result = draw_vqa_overlay(canvas, malformed)
    assert result.size == canvas.size
def test_observation_image_to_pil_from_batched_float_array():
    """A batched float32 array of shape ``(1, 3, 24, 32)`` converts to a 32x24 RGB image."""
    # (1, C, H, W) float array in [0, 1], the runtime observation shape.
    batch_shape = (1, 3, 24, 32)
    frame = np.zeros(batch_shape, dtype=np.float32)
    image = observation_image_to_pil(frame)
    assert image.size == (32, 24)
    assert image.mode == "RGB"
# ---------------------------------------------------------------------------
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
# ---------------------------------------------------------------------------
def test_parse_loc_keypoint_answer():
    """Two ``<loc>`` tokens plus a label parse as a normalized keypoint."""
    # <locY><locX> label — y=512/1023≈0.5, x=256/1023≈0.25
    result = parse_vqa_answer("<loc0512><loc0256> blue cube")
    assert result["kind"] == "keypoint"
    assert result["normalized"] is True
    payload = result["payload"]
    point_x, point_y = payload["point"]
    assert 0.24 < point_x < 0.26
    assert 0.49 < point_y < 0.51
    assert payload["label"] == "blue cube"
    assert answer_has_overlay(result)
def test_parse_loc_bbox_answer():
    """Four ``<loc>`` tokens plus a label parse as one normalized, well-ordered box."""
    # <locY0><locX0><locY1><locX1> label
    result = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
    assert result["kind"] == "bbox"
    assert result["normalized"] is True
    detection = result["payload"]["detections"][0]
    left, top, right, bottom = detection["bbox"]
    assert left < right
    assert top < bottom
    assert detection["label"] == "yellow cube"
    assert answer_has_overlay(result)
def test_parse_loc_multiple_boxes():
    """Two ``;``-separated ``<loc>`` boxes parse into two detections."""
    raw_answer = "<loc0100><loc0080><loc0400><loc0360> cube ; <loc0200><loc0500><loc0600><loc0900> box"
    result = parse_vqa_answer(raw_answer)
    assert result["kind"] == "bbox"
    detections = result["payload"]["detections"]
    assert len(detections) == 2
def test_parse_loc_takes_precedence_over_json():
    """An answer containing ``<loc>`` tokens is parsed as loc even if it looks JSON-ish."""
    json_like_answer = '{"x": <loc0001><loc0002>}'
    result = parse_vqa_answer(json_like_answer)
    assert result["normalized"] is True
def test_draw_loc_overlay_denormalizes_to_pixels():
    """Drawing a ``<loc>`` keypoint on a 200x100 image keeps its size and alters pixels."""
    canvas = _blank((200, 100))
    centre_answer = "<loc0511><loc0511> cube"  # ~centre
    annotated = draw_vqa_overlay(canvas, parse_vqa_answer(centre_answer))
    assert annotated.size == canvas.size
    assert annotated.tobytes() != canvas.tobytes()