mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
examples(port_datasets): SLURM+datatrove RoboCasa composite_seen build
Parallel variant of build_robocasa_composite_seen.py modeled after the
existing slurm_port_shards.py / slurm_aggregate_shards.py pattern.
Two-phase datatrove pipeline:
* Phase 1 DOWNLOAD: tasks=16 (one per RoboCasa composite_seen task),
each worker downloads its assigned tar via RoboCasa's own
download_datasets helper. Network-bound, idempotent.
* Phase 2 AGGREGATE: tasks=1, single worker calls aggregate_datasets
over the 16 extracted directories. Submitted with depends=phase1 so
SLURM only releases it once all 16 downloads succeed.
Reuses the COMPOSITE_SEEN_TASKS list and per-task download/resolve
helpers from the single-machine script via aliased imports — single
source of truth for 'what does it mean to download a composite_seen
task'.
Local (--slurm 0) mode runs the two phases sequentially in-process for
debugging on a workstation.
Usage on SLURM:
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \
--output-dir=/scratch/${USER}/robocasa_composite_seen \
--hub-repo-id=${HF_USER}/robocasa_composite_seen \
--logs-dir=/scratch/${USER}/logs/robocasa \
--partition=cpu --push-to-hub
Prereq: uv sync --extra annotations (pulls datatrove)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,541 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Build the merged RoboCasa composite_seen dataset on SLURM via datatrove.
|
||||
|
||||
Two-phase pipeline modeled after ``slurm_port_shards.py`` +
|
||||
``slurm_aggregate_shards.py``:
|
||||
|
||||
* Phase 1 — DOWNLOAD (parallel, 16 tasks, 1 per worker):
|
||||
Each datatrove worker downloads one of the 16 RoboCasa composite_seen task
|
||||
archives (``v1.0/target/composite/<Task>/<date>/lerobot.tar``) via
|
||||
RoboCasa's own ``download_datasets`` helper. Idempotent — already-extracted
|
||||
tasks are skipped. Network-bound, CPU-light.
|
||||
|
||||
* Phase 2 — AGGREGATE (single worker, depends on phase 1):
|
||||
One worker calls ``aggregate_datasets`` over the 16 extracted directories,
|
||||
producing a single combined LeRobotDataset. Validates fps / robot_type /
|
||||
features, unifies task indices, concatenates videos + parquet, recomputes
|
||||
stats. CPU + disk heavy.
|
||||
|
||||
When run under SLURM, phase 2 is submitted with ``depends=phase_1_executor``
|
||||
so the scheduler only releases it after every download task succeeds.
|
||||
|
||||
Local (``--slurm 0``) execution runs the two phases sequentially in the
|
||||
current process — useful for debugging on a workstation.
|
||||
|
||||
Usage on SLURM::
|
||||
|
||||
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\
|
||||
--output-dir=/scratch/${USER}/robocasa_composite_seen \\
|
||||
--hub-repo-id=${HF_USER}/robocasa_composite_seen \\
|
||||
--logs-dir=/scratch/${USER}/logs/robocasa \\
|
||||
--partition=cpu \\
|
||||
--download-cpus=4 --download-mem=8G \\
|
||||
--aggregate-cpus=16 --aggregate-mem-per-cpu=4G \\
|
||||
--push-to-hub
|
||||
|
||||
Local debug (sequential, single process)::
|
||||
|
||||
uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\
|
||||
--output-dir=/tmp/robocasa_composite_seen \\
|
||||
--slurm=0 \\
|
||||
--tasks=PrepareCoffee,KettleBoiling
|
||||
|
||||
Prereqs: ``robocasa`` + ``robosuite`` installed (see
|
||||
``docs/source/benchmarks/robocasa.mdx``); ``datatrove`` installed via the
|
||||
``annotations`` extra (``uv sync --extra annotations``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
# Reuse the per-task helpers + canonical task list from the single-machine
|
||||
# script so both runners share one source of truth for "what does it mean to
|
||||
# download a composite_seen task". The helpers are spelled with leading
|
||||
# underscores there (module-private), but the slurm runner is a legitimate
|
||||
# in-tree consumer, so we alias them to clean names here.
|
||||
from lerobot.scripts.build_robocasa_composite_seen import (
|
||||
COMPOSITE_SEEN_TASKS,
|
||||
_download_task as download_task,
|
||||
_require_robocasa as require_robocasa,
|
||||
_resolve_task_root as resolve_task_root,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pipeline steps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DownloadRoboCasaTask(PipelineStep):
|
||||
"""Phase 1 — download the task assigned to this rank.
|
||||
|
||||
Each datatrove worker is given a rank in ``[0, world_size)``. With
|
||||
``tasks == len(task_names)`` each worker owns exactly one task; with
|
||||
fewer workers than tasks, datatrove load-balances tasks across workers
|
||||
using the standard ``rank::world_size`` slicing.
|
||||
"""
|
||||
|
||||
def __init__(self, task_names: list[str], *, overwrite: bool = False):
|
||||
super().__init__()
|
||||
self.task_names = list(task_names)
|
||||
self.overwrite = overwrite
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from lerobot.utils.utils import init_logging # noqa: PLC0415
|
||||
|
||||
init_logging()
|
||||
require_robocasa()
|
||||
|
||||
# Standard datatrove slicing: each rank owns the subset
|
||||
# ``task_names[rank::world_size]``. When ``world_size ==
|
||||
# len(task_names)`` this is exactly one task per rank.
|
||||
my_tasks = self.task_names[rank::world_size]
|
||||
if not my_tasks:
|
||||
logger.info("rank %d/%d: no tasks assigned", rank, world_size)
|
||||
return
|
||||
|
||||
for task in my_tasks:
|
||||
logger.info("rank %d/%d: downloading %s", rank, world_size, task)
|
||||
root = download_task(task, overwrite=self.overwrite)
|
||||
logger.info("rank %d/%d: %s extracted at %s", rank, world_size, task, root)
|
||||
|
||||
|
||||
class AggregateRoboCasaShards(PipelineStep):
|
||||
"""Phase 2 — merge all 16 extracted directories into one LeRobotDataset.
|
||||
|
||||
``aggregate_datasets`` parallelizes internally; only rank 0 runs the
|
||||
merge (mirrors the DROID ``slurm_aggregate_shards.py`` convention).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_names: list[str],
|
||||
*,
|
||||
output_repo_id: str,
|
||||
output_dir: Path,
|
||||
push_to_hub: bool,
|
||||
private: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.task_names = list(task_names)
|
||||
self.output_repo_id = output_repo_id
|
||||
self.output_dir = Path(output_dir)
|
||||
self.push_to_hub = push_to_hub
|
||||
self.private = private
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from lerobot.utils.utils import init_logging # noqa: PLC0415
|
||||
|
||||
init_logging()
|
||||
|
||||
if rank != 0:
|
||||
logger.info("rank %d: aggregation runs on rank 0 only — skipping", rank)
|
||||
return
|
||||
|
||||
require_robocasa()
|
||||
from lerobot.datasets import aggregate_datasets # noqa: PLC0415
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||
|
||||
# Resolve each task's local extraction root. After phase 1 these
|
||||
# all exist on disk under robocasa.macros.DATASET_BASE_DIR; if any
|
||||
# are missing, fail loudly so the operator knows phase 1 didn't
|
||||
# cleanly complete for that task.
|
||||
roots: list[Path] = []
|
||||
missing: list[str] = []
|
||||
for task in self.task_names:
|
||||
root = resolve_task_root(task)
|
||||
if not root.exists():
|
||||
missing.append(f"{task} -> {root}")
|
||||
else:
|
||||
roots.append(root)
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
"Phase 1 did not produce extracted directories for: "
|
||||
+ ", ".join(missing)
|
||||
+ " — re-run the download phase before aggregating."
|
||||
)
|
||||
|
||||
# ``aggregate_datasets`` uses ``repo_ids`` purely for logging /
|
||||
# the unified task table when ``roots=`` is supplied; the actual
|
||||
# data is loaded from each root directly.
|
||||
repo_ids = [f"robocasa/{task}_target_human" for task in self.task_names]
|
||||
|
||||
logger.info(
|
||||
"Aggregating %d source datasets into %s at %s",
|
||||
len(roots),
|
||||
self.output_repo_id,
|
||||
self.output_dir,
|
||||
)
|
||||
aggregate_datasets(
|
||||
repo_ids=repo_ids,
|
||||
aggr_repo_id=self.output_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=self.output_dir,
|
||||
)
|
||||
logger.info("Aggregation complete.")
|
||||
|
||||
if self.push_to_hub:
|
||||
merged = LeRobotDataset(
|
||||
repo_id=self.output_repo_id,
|
||||
root=self.output_dir,
|
||||
)
|
||||
logger.info(
|
||||
"Pushing %s to the Hub (private=%s, %d episodes, %d frames)",
|
||||
self.output_repo_id,
|
||||
self.private,
|
||||
merged.num_episodes,
|
||||
merged.num_frames,
|
||||
)
|
||||
merged.push_to_hub(
|
||||
private=self.private,
|
||||
upload_large_folder=True,
|
||||
tags=["lerobot", "robocasa", "composite_seen", "manipulation"],
|
||||
)
|
||||
logger.info(
|
||||
"Push complete: https://huggingface.co/datasets/%s",
|
||||
self.output_repo_id,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Executors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_download_executor(
|
||||
*,
|
||||
task_names: list[str],
|
||||
overwrite: bool,
|
||||
job_name: str,
|
||||
logs_dir: Path,
|
||||
workers: int,
|
||||
partition: str | None,
|
||||
cpus_per_task: int,
|
||||
mem: str,
|
||||
time: str,
|
||||
slurm: bool,
|
||||
):
|
||||
"""Phase-1 executor: parallel downloads, one task per worker by default."""
|
||||
pipeline = [DownloadRoboCasaTask(task_names, overwrite=overwrite)]
|
||||
logging_dir = str(logs_dir / job_name)
|
||||
|
||||
if slurm:
|
||||
return SlurmPipelineExecutor(
|
||||
pipeline=pipeline,
|
||||
logging_dir=logging_dir,
|
||||
job_name=job_name,
|
||||
tasks=len(task_names), # one shard per RoboCasa task
|
||||
workers=workers,
|
||||
time=time,
|
||||
partition=partition,
|
||||
cpus_per_task=cpus_per_task,
|
||||
sbatch_args={"mem": mem},
|
||||
)
|
||||
return LocalPipelineExecutor(
|
||||
pipeline=pipeline,
|
||||
logging_dir=logging_dir,
|
||||
tasks=len(task_names),
|
||||
workers=min(workers, len(task_names)),
|
||||
)
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
*,
|
||||
task_names: list[str],
|
||||
output_repo_id: str,
|
||||
output_dir: Path,
|
||||
push_to_hub: bool,
|
||||
private: bool,
|
||||
job_name: str,
|
||||
logs_dir: Path,
|
||||
partition: str | None,
|
||||
cpus_per_task: int,
|
||||
mem_per_cpu: str,
|
||||
time: str,
|
||||
slurm: bool,
|
||||
depends: SlurmPipelineExecutor | None,
|
||||
):
|
||||
"""Phase-2 executor: single worker, aggregates the extracted shards."""
|
||||
pipeline = [
|
||||
AggregateRoboCasaShards(
|
||||
task_names,
|
||||
output_repo_id=output_repo_id,
|
||||
output_dir=output_dir,
|
||||
push_to_hub=push_to_hub,
|
||||
private=private,
|
||||
)
|
||||
]
|
||||
logging_dir = str(logs_dir / job_name)
|
||||
|
||||
if slurm:
|
||||
return SlurmPipelineExecutor(
|
||||
pipeline=pipeline,
|
||||
logging_dir=logging_dir,
|
||||
job_name=job_name,
|
||||
tasks=1,
|
||||
workers=1,
|
||||
time=time,
|
||||
partition=partition,
|
||||
cpus_per_task=cpus_per_task,
|
||||
sbatch_args={"mem-per-cpu": mem_per_cpu},
|
||||
depends=depends, # SLURM job dependency on phase 1
|
||||
)
|
||||
return LocalPipelineExecutor(
|
||||
pipeline=pipeline,
|
||||
logging_dir=logging_dir,
|
||||
tasks=1,
|
||||
workers=1,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Build the merged RoboCasa composite_seen LeRobotDataset on SLURM "
|
||||
"via datatrove (download in parallel, aggregate sequentially)."
|
||||
),
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
|
||||
# I/O.
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Local directory for the merged dataset (will be created).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Hub repo_id for the merged dataset (e.g. "
|
||||
"``yourname/robocasa_composite_seen``). Required for "
|
||||
"``--push-to-hub``; also becomes the merged dataset's "
|
||||
"canonical ``repo_id``."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push the merged dataset to the Hub after aggregation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="When pushing, create the Hub repo as private.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
default=Path("./logs/robocasa"),
|
||||
help="Path to datatrove logs directory (used for stdout/stderr and "
|
||||
"phase coordination).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated task names overriding the default 16 "
|
||||
"composite_seen list (debug / smoke-test).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite-download",
|
||||
action="store_true",
|
||||
help="Force re-download even if the local extraction looks complete.",
|
||||
)
|
||||
|
||||
# SLURM controls.
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM (``1``) or locally / sequentially (``0``).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
default=None,
|
||||
help="SLURM partition. A CPU partition is sufficient — no GPU needed.",
|
||||
)
|
||||
|
||||
# Phase-1 (download) sizing.
|
||||
parser.add_argument(
|
||||
"--download-workers",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of parallel SLURM workers for the download phase. "
|
||||
"Default matches the number of composite_seen tasks (16).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download-cpus",
|
||||
type=int,
|
||||
default=4,
|
||||
help="CPUs per download worker (the work is network- and "
|
||||
"tar-extraction-bound).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download-mem",
|
||||
type=str,
|
||||
default="8G",
|
||||
help="Total memory per download worker.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--download-time",
|
||||
type=str,
|
||||
default="06:00:00",
|
||||
help="SLURM wall-clock limit per download worker (HH:MM:SS).",
|
||||
)
|
||||
|
||||
# Phase-2 (aggregate) sizing.
|
||||
parser.add_argument(
|
||||
"--aggregate-cpus",
|
||||
type=int,
|
||||
default=16,
|
||||
help="CPUs for the aggregation worker (ffmpeg + parquet I/O parallelize).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aggregate-mem-per-cpu",
|
||||
type=str,
|
||||
default="2G",
|
||||
help="SLURM mem-per-cpu for the aggregation worker.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aggregate-time",
|
||||
type=str,
|
||||
default="12:00:00",
|
||||
help="SLURM wall-clock limit for aggregation (HH:MM:SS). Tens-of-GB "
|
||||
"merges can take several hours.",
|
||||
)
|
||||
|
||||
# Job naming.
|
||||
parser.add_argument(
|
||||
"--download-job-name",
|
||||
type=str,
|
||||
default="robocasa_dl",
|
||||
help="SLURM job name for phase 1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--aggregate-job-name",
|
||||
type=str,
|
||||
default="robocasa_agg",
|
||||
help="SLURM job name for phase 2.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format="[%(levelname)s] %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
tasks = (
|
||||
[t.strip() for t in args.tasks.split(",") if t.strip()]
|
||||
if args.tasks
|
||||
else list(COMPOSITE_SEEN_TASKS)
|
||||
)
|
||||
if not tasks:
|
||||
raise SystemExit("No tasks selected.")
|
||||
|
||||
if args.push_to_hub and not args.hub_repo_id:
|
||||
raise SystemExit("--push-to-hub requires --hub-repo-id.")
|
||||
|
||||
output_repo_id = args.hub_repo_id or "local/robocasa_composite_seen"
|
||||
slurm = args.slurm == 1
|
||||
|
||||
logger.info(
|
||||
"Phase 1 (download) — %d tasks across %d workers (slurm=%s)",
|
||||
len(tasks),
|
||||
min(args.download_workers, len(tasks)),
|
||||
slurm,
|
||||
)
|
||||
download_executor = make_download_executor(
|
||||
task_names=tasks,
|
||||
overwrite=args.overwrite_download,
|
||||
job_name=args.download_job_name,
|
||||
logs_dir=args.logs_dir,
|
||||
workers=args.download_workers,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.download_cpus,
|
||||
mem=args.download_mem,
|
||||
time=args.download_time,
|
||||
slurm=slurm,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Phase 2 (aggregate) — single worker, output: %s (push_to_hub=%s)",
|
||||
output_repo_id,
|
||||
args.push_to_hub,
|
||||
)
|
||||
aggregate_executor = make_aggregate_executor(
|
||||
task_names=tasks,
|
||||
output_repo_id=output_repo_id,
|
||||
output_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
private=args.private,
|
||||
job_name=args.aggregate_job_name,
|
||||
logs_dir=args.logs_dir,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.aggregate_cpus,
|
||||
mem_per_cpu=args.aggregate_mem_per_cpu,
|
||||
time=args.aggregate_time,
|
||||
slurm=slurm,
|
||||
depends=download_executor if slurm else None,
|
||||
)
|
||||
|
||||
if slurm:
|
||||
# Submitting the aggregate executor with ``depends=download_executor``
|
||||
# also submits the download executor — SlurmPipelineExecutor walks
|
||||
# the dependency chain and submits each job once with the right
|
||||
# ``--dependency=afterok:<jobid>`` arg.
|
||||
aggregate_executor.run()
|
||||
else:
|
||||
# Local sequential: run download to completion, then aggregate.
|
||||
download_executor.run()
|
||||
aggregate_executor.run()
|
||||
|
||||
logger.info("Done. Merged dataset at %s.", args.output_dir)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
+2
-6
@@ -308,12 +308,8 @@ lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
# Interactive hierarchical-VLA runtime. The same entry point drives both
|
||||
# SmolVLA2 (SmolVLM2 backbone) and PI052 (PaliGemma backbone) — the
|
||||
# policy type is read from the checkpoint. ``lerobot-pi052-runtime`` is
|
||||
# an alias so the command name isn't misleading for PI052 users.
|
||||
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||
lerobot-pi052-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"
|
||||
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
|
||||
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
|
||||
@@ -129,8 +129,8 @@ class Executor:
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Persist the tool catalog to meta/info.json so chat-template
|
||||
# consumers (PR 3 SmolVLA2 / Pi0.5 / dataset visualizer) can read
|
||||
# Persist the tool catalog to meta/info.json so downstream
|
||||
# consumers (PI052 / Pi0.5 / dataset visualizer) can read
|
||||
# it via ``LeRobotDatasetMetadata.tools`` (PR 1). Idempotent and
|
||||
# additive: anything the user pre-populated is preserved; we only
|
||||
# ensure the canonical ``say`` schema is present.
|
||||
|
||||
@@ -20,11 +20,10 @@
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# SmolVLA2 note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the chat tokenizer
|
||||
# (`_flatten_say_tool_calls`) before `apply_chat_template`, so the LM
|
||||
# head learns to emit exactly the marker the runtime parses back
|
||||
# (`_split_plan_and_say`).
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
|
||||
@@ -20,11 +20,10 @@
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# SmolVLA2 note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the chat tokenizer
|
||||
# (`_flatten_say_tool_calls`) before `apply_chat_template`, so the LM
|
||||
# head learns to emit exactly the marker the runtime parses back
|
||||
# (`_split_plan_and_say`).
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# subtasks_vqa — Hi-Robot blend, shared between SmolVLA2 (SmolVLM2
|
||||
# backbone) and PI052 (PaliGemma backbone).
|
||||
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
|
||||
#
|
||||
# Trains two things only: subtasks and VQA. Plan and memory are
|
||||
# intentionally left out — keeps the prompt short and the training
|
||||
@@ -10,9 +9,8 @@
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Each backbone's text tokenizer renders these messages differently
|
||||
# (SmolVLA2 uses the chat template; PI052 concatenates as plain
|
||||
# ``Role: content`` text), but the recipe spec is identical.
|
||||
# PI052's text tokenizer renders these messages as plain
|
||||
# ``Role: content`` text (PaliGemma is not chat-pretrained).
|
||||
|
||||
blend:
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -52,7 +51,6 @@ __all__ = [
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"SmolVLA2Config",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"WallXConfig",
|
||||
|
||||
@@ -144,10 +144,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "smolvla2":
|
||||
from .smolvla2.modeling_smolvla2 import SmolVLA2Policy
|
||||
|
||||
return SmolVLA2Policy
|
||||
elif name == "sarm":
|
||||
from .sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
@@ -180,8 +176,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
|
||||
"pi052", "sac", "smolvla", "reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -212,10 +208,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "smolvla2":
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
return SmolVLA2Config(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
@@ -424,17 +416,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "smolvla2":
|
||||
# NOTE: SmolVLA2Config subclasses SmolVLAConfig, so this branch
|
||||
# MUST come before the SmolVLAConfig isinstance check below
|
||||
# (otherwise SmolVLA2 would silently pick up SmolVLA's processor).
|
||||
from .smolvla2.processor_smolvla2 import make_smolvla2_pre_post_processors
|
||||
|
||||
processors = make_smolvla2_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ This is the dual-head co-training pattern from the paper:
|
||||
|
||||
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
|
||||
inference into a text-prediction step followed by an action-prediction
|
||||
step, which mirrors what ``SmolVLA2Runtime`` already does on a
|
||||
SmolVLM2 backbone.
|
||||
step, which the multi-rate ``PI052Runtime`` (in
|
||||
``lerobot.policies.pi052.inference``) drives at separate rates.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -48,10 +48,9 @@ from ..pi05.configuration_pi05 import PI05Config
|
||||
class PI052Config(PI05Config):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
|
||||
|
||||
See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head
|
||||
config. Same recipe-driven training surface; the only difference is
|
||||
which backbone the policy uses (PaliGemma here vs SmolVLM2 there).
|
||||
The flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
|
||||
Recipe-driven dual-head training: the flow head supervises actions,
|
||||
the LM head supervises subtask / plan / memory / VQA text. The
|
||||
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
@@ -64,17 +63,17 @@ class PI052Config(PI05Config):
|
||||
|
||||
apply_chat_template: bool = False
|
||||
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
|
||||
chat template. So unlike SmolVLA2 we don't apply one. The recipe
|
||||
renderer's output is concatenated as a plain prefix + assistant
|
||||
suffix instead, mirroring how the π0.5 paper's high-level inference
|
||||
samples text auto-regressively after the prefix."""
|
||||
chat template, so we don't apply one. The recipe renderer's output
|
||||
is concatenated as a plain prefix + assistant suffix instead,
|
||||
mirroring how the π0.5 paper's high-level inference samples text
|
||||
auto-regressively after the prefix."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
# Paper §IV.D uses α=10 between the flow and text terms, assuming
|
||||
# text is a rare auxiliary task. With the recipe stack the flow-only
|
||||
# `low_level` branch fires on a large share of samples, so α=10
|
||||
# swamps the LM head and collapses generation into degenerate
|
||||
# repetition. We use the milder 5:1 split (matches SmolVLA2Config).
|
||||
# repetition. We use the milder 5:1 split here.
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / π0.5 behaviour)."""
|
||||
@@ -93,8 +92,8 @@ class PI052Config(PI05Config):
|
||||
hierarchical inference."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# Same regulariser surface as SmolVLA2: randomly drop non-target
|
||||
# context messages so the LM head learns to handle missing /
|
||||
# Randomly drop non-target context messages so the LM head learns
|
||||
# to handle missing /
|
||||
# stale plan / memory at inference. Defaults to 0.0 so behaviour
|
||||
# is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
|
||||
+5
-5
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 inference / runtime orchestration.
|
||||
"""PI052 inference / runtime orchestration.
|
||||
|
||||
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
|
||||
@@ -22,12 +22,12 @@ Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||
|
||||
The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and
|
||||
calls ``run()``.
|
||||
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
|
||||
``run()``.
|
||||
"""
|
||||
|
||||
from .repl import StdinReader
|
||||
from .runtime import SmolVLA2Runtime
|
||||
from .runtime import PI052Runtime
|
||||
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
@@ -44,7 +44,7 @@ from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||
|
||||
__all__ = [
|
||||
# runtime
|
||||
"SmolVLA2Runtime",
|
||||
"PI052Runtime",
|
||||
"StdinReader",
|
||||
# state helpers
|
||||
"initial_runtime_state",
|
||||
+3
-3
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stdin REPL event collector for the SmolVLA2 runtime.
|
||||
"""Stdin REPL event collector for the PI052 runtime.
|
||||
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
@@ -23,7 +23,7 @@ Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
|
||||
Note: the shipped CLI (``lerobot-smolvla2-runtime``) drives stdin
|
||||
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
|
||||
directly in its REPL / autonomous loops and does *not* wire this
|
||||
collector; it's kept as the documented embedding hook and for tests.
|
||||
"""
|
||||
@@ -92,7 +92,7 @@ class StdinReader:
|
||||
if not state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
state["task"] = task
|
||||
print(f"[smolvla2] Task: {task}", flush=True)
|
||||
print(f"[pi052] Task: {task}", flush=True)
|
||||
self._seen_first_line = True
|
||||
return
|
||||
|
||||
+4
-4
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 runtime loop.
|
||||
"""PI052 runtime loop.
|
||||
|
||||
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||
event collector, drives ticks through :class:`TickClock`, and prints
|
||||
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmolVLA2Runtime:
|
||||
class PI052Runtime:
|
||||
"""Compose the inference pipeline and drive it tick-by-tick."""
|
||||
|
||||
policy: Any
|
||||
@@ -195,11 +195,11 @@ class SmolVLA2Runtime:
|
||||
|
||||
def _flush_logs(self) -> None:
|
||||
for line in self.state.get("log_lines") or []:
|
||||
print(f"[smolvla2] {line}", flush=True)
|
||||
print(f"[pi052] {line}", flush=True)
|
||||
|
||||
def _on_shutdown(self) -> None:
|
||||
# Drain any queued action chunks safely.
|
||||
queue = self.state.get("action_queue")
|
||||
if isinstance(queue, deque):
|
||||
queue.clear()
|
||||
print("[smolvla2] runtime stopped", flush=True)
|
||||
print("[pi052] runtime stopped", flush=True)
|
||||
+8
-130
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference steps for the SmolVLA2 multi-rate runtime.
|
||||
"""Inference steps for the PI052 multi-rate runtime.
|
||||
|
||||
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
|
||||
the runtime applies them in order each tick. When a step's trigger
|
||||
@@ -98,7 +98,7 @@ class LowLevelForward(InferenceStep):
|
||||
if not state.get("task"):
|
||||
return None
|
||||
|
||||
# SmolVLA produces *action chunks* (typically 50 steps via
|
||||
# PI052 produces *action chunks* (typically 50 steps via
|
||||
# flow-matching). Every step gets dispatched to the robot;
|
||||
# popping one per dispatch tick is essentially free. Only
|
||||
# generate a new chunk once the previous one has fully
|
||||
@@ -256,34 +256,11 @@ def _build_text_batch(
|
||||
) -> dict[str, Any]:
|
||||
"""Tokenize chat messages into the batch ``select_message`` expects.
|
||||
|
||||
Dispatches on the policy backbone so one runtime drives both:
|
||||
|
||||
* ``smolvla2`` (SmolVLM2) — chat template via ``apply_chat_template``.
|
||||
* ``pi052`` (PaliGemma) — flat ``Role: content`` text, since
|
||||
PaliGemma is not chat-pretrained (mirrors ``PI052TextTokenizerStep``).
|
||||
"""
|
||||
if getattr(getattr(policy, "config", None), "type", "") == "pi052":
|
||||
return _build_text_batch_pi052(
|
||||
policy, prompt_messages, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
return _build_text_batch_chat(
|
||||
policy, prompt_messages, add_generation_prompt=add_generation_prompt
|
||||
)
|
||||
|
||||
|
||||
def _build_text_batch_pi052(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""PI052 text batch — flat ``User: … \\nAssistant: …`` prompt.
|
||||
|
||||
PaliGemma ships no chat template, so PI052 trains on the plain
|
||||
role-prefixed concatenation built by ``PI052TextTokenizerStep``.
|
||||
Reuses that exact formatter so the inference prefix matches
|
||||
training. ``add_generation_prompt`` appends the bare ``Assistant: ``
|
||||
header the LM head continues from.
|
||||
PI052's backbone (PaliGemma) ships no chat template, so we train on
|
||||
a plain role-prefixed concatenation built by
|
||||
``PI052TextTokenizerStep``. We reuse that exact formatter so the
|
||||
inference prefix matches training; ``add_generation_prompt`` appends
|
||||
the bare ``Assistant: `` header the LM head continues from.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
@@ -315,99 +292,6 @@ def _build_text_batch_pi052(
|
||||
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
|
||||
attn = attn.bool()
|
||||
|
||||
device = getattr(getattr(policy, "config", None), "device", None)
|
||||
if device is not None:
|
||||
try:
|
||||
ids = ids.to(device)
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
def _build_text_batch_chat(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""SmolVLA2 (SmolVLM2) text batch — chat-template tokenization.
|
||||
|
||||
Reuses ``_strip_lerobot_blocks`` so the inference prompt shape
|
||||
matches the training-time chat tokenizer step exactly.
|
||||
"""
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name)
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Reuse the *exact* normaliser that the training-time chat
|
||||
# tokenizer step uses (``_strip_lerobot_blocks``). It handles all
|
||||
# the cases the SmolVLM chat template expects:
|
||||
# * ``content: list[block]`` → keep text blocks, drop images
|
||||
# * ``content: None`` → ``[{type: text, text: ""}]``
|
||||
# * ``content: str`` / anything else → ``[{type: text, text: str(content)}]``
|
||||
# Doing it any other way creates a training/inference mismatch in
|
||||
# exactly the prompt shape the model was supervised on. Also
|
||||
# strips ``stream`` / ``target`` recipe metadata.
|
||||
from lerobot.policies.smolvla2.chat_processor_smolvla2 import ( # noqa: PLC0415
|
||||
_strip_lerobot_blocks,
|
||||
)
|
||||
|
||||
text_messages = [_strip_lerobot_blocks(m) for m in prompt_messages]
|
||||
encoded = tokenizer.apply_chat_template(
|
||||
text_messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
# ``apply_chat_template`` can return any of:
|
||||
# - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers),
|
||||
# - a list[int] / list[list[int]] (when ``return_tensors`` is ignored),
|
||||
# - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask``
|
||||
# (newer transformers, especially via processor.apply_chat_template
|
||||
# forwarding through here).
|
||||
# Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's
|
||||
# attention mask when available so we don't have to re-derive it
|
||||
# from ``pad_token_id`` (which can be ``None`` for SmolVLM).
|
||||
attn: Any = None
|
||||
if hasattr(encoded, "input_ids"):
|
||||
ids = encoded.input_ids
|
||||
attn = getattr(encoded, "attention_mask", None)
|
||||
elif isinstance(encoded, dict) and "input_ids" in encoded:
|
||||
ids = encoded["input_ids"]
|
||||
attn = encoded.get("attention_mask")
|
||||
else:
|
||||
ids = encoded
|
||||
if isinstance(ids, list):
|
||||
if ids and isinstance(ids[0], list):
|
||||
ids = ids[0]
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
ids = torch.tensor(ids, dtype=torch.long)
|
||||
if hasattr(ids, "ndim") and ids.ndim == 1:
|
||||
ids = ids.unsqueeze(0)
|
||||
if attn is None and tokenizer.pad_token_id is not None:
|
||||
attn = ids != tokenizer.pad_token_id
|
||||
elif isinstance(attn, list):
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
attn = torch.tensor(attn, dtype=torch.long)
|
||||
if attn.ndim == 1:
|
||||
attn = attn.unsqueeze(0)
|
||||
# SmolVLA's ``eager_attention_forward`` does
|
||||
# ``torch.where(attention_mask[..., None, :, :], ...)`` which
|
||||
# requires a *bool* condition tensor; ``BatchEncoding``'s
|
||||
# attention_mask is typically Long (0/1). Cast so the prefix
|
||||
# forward doesn't blow up with ``where expected condition to be a
|
||||
# boolean tensor, but got a tensor with dtype Long``.
|
||||
if attn is not None and hasattr(attn, "dtype"):
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
if attn.dtype != _torch.bool:
|
||||
attn = attn.bool()
|
||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||
# model), which the caller's broad except would swallow silently.
|
||||
@@ -418,7 +302,7 @@ def _build_text_batch_chat(
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move lang tokens to %s: %s", device, exc)
|
||||
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
@@ -1022,12 +906,6 @@ def _generate_with_policy(
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
# Only pass ``suppress_loc_tokens`` to backbones that accept it
|
||||
# (pi052). SmolVLA2's ``select_message`` does not, so we omit
|
||||
# the kwarg there to avoid breaking the shared runtime.
|
||||
import inspect # noqa: PLC0415
|
||||
|
||||
if "suppress_loc_tokens" in inspect.signature(policy.select_message).parameters:
|
||||
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
|
||||
return policy.select_message(batch, **kwargs)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
+1
-1
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trigger primitives for SmolVLA2's multi-rate inference runtime.
|
||||
"""Trigger primitives for PI052's multi-rate inference runtime.
|
||||
|
||||
Mirrors the plan's Section "Runtime orchestration": each
|
||||
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
|
||||
+2
-2
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rich-based REPL layout for the SmolVLA2 runtime.
|
||||
"""Rich-based REPL layout for the PI052 runtime.
|
||||
|
||||
Two-zone terminal layout:
|
||||
|
||||
@@ -97,7 +97,7 @@ def make_state_panel(state: dict[str, Any]) -> Any:
|
||||
)
|
||||
return Panel(
|
||||
table,
|
||||
title=f"[bold]SmolVLA2 state[/] · mode: {mode_tag}",
|
||||
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
||||
+2
-2
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Interactive VQA for the SmolVLA2 runtime.
|
||||
"""Interactive VQA for the PI052 runtime.
|
||||
|
||||
In ``/vlm`` mode a typed line is treated as a VQA question. This module
|
||||
runs the full interactive flow:
|
||||
@@ -379,7 +379,7 @@ def handle_vqa_query(
|
||||
# Feed the FULL observation (every camera + state) to the VLM. The
|
||||
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
|
||||
# stripped before tokenization — the actual frames reach the model
|
||||
# via SmolVLA's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
|
||||
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
|
||||
# consumes *all* ``config.image_features`` regardless of which
|
||||
# camera the sub-recipe was tagged for. So the model always sees
|
||||
# every camera; the operator never has to name one to ask.
|
||||
@@ -29,10 +29,10 @@ A thin subclass of :class:`PI05Policy` that:
|
||||
|
||||
with α controllable via ``config.flow_loss_weight``.
|
||||
|
||||
The same multi-rate runtime that drives ``SmolVLA2Runtime`` (see
|
||||
``lerobot.policies.smolvla2.inference``) can drive this policy too —
|
||||
both expose ``predict_action_chunk`` for the action expert and
|
||||
``select_message`` for the LM head.
|
||||
The multi-rate inference runtime in ``lerobot.policies.pi052.inference``
|
||||
(driven by the ``lerobot-pi052-runtime`` CLI) sits on top of this:
|
||||
``predict_action_chunk`` for the action expert and ``select_message``
|
||||
for the LM head.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -380,9 +380,9 @@ class PI052Policy(PI05Policy):
|
||||
for p in backbone.lm_head.parameters():
|
||||
p.requires_grad_(True)
|
||||
# The text model's final norm and last transformer block —
|
||||
# mirror SmolVLA2's logic, which finds these dynamically by
|
||||
# the trainable=False parameters that point at the head's
|
||||
# neighbourhood.
|
||||
# find them dynamically by walking up from the LM head so we
|
||||
# don't hard-code module names that may drift across transformers
|
||||
# versions.
|
||||
text_model = getattr(backbone, "model", None)
|
||||
text_model = getattr(text_model, "language_model", text_model)
|
||||
if text_model is None:
|
||||
@@ -934,9 +934,8 @@ class PI052Policy(PI05Policy):
|
||||
) -> str:
|
||||
"""Generate text continuation from a multimodal prefix.
|
||||
|
||||
Mirrors ``SmolVLA2Policy.select_message`` so the same
|
||||
:class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime`
|
||||
can drive π0.5 v2 unchanged.
|
||||
Consumed by :class:`lerobot.policies.pi052.inference.PI052Runtime`
|
||||
for the high-level / VQA / memory-update text streams.
|
||||
|
||||
``suppress_loc_tokens`` masks PaliGemma's reserved ``<locDDDD>``
|
||||
ids ([256000, 257024)) to ``-inf`` before sampling. PaliGemma's
|
||||
|
||||
@@ -182,8 +182,7 @@ def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/`` (same lookup rules as
|
||||
``make_smolvla2_pre_post_processors``).
|
||||
``src/lerobot/configs/``.
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
"""π0.5 v2 text-tokenisation step.
|
||||
|
||||
PaliGemma is *not* chat-pretrained, so unlike SmolVLA2 we can't lean on
|
||||
PaliGemma is *not* chat-pretrained, so we can't lean on
|
||||
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
|
||||
messages as plain text with simple ``User: ... Assistant: ...`` role
|
||||
delimiters — matching the prompt format π0.5 uses in the paper
|
||||
@@ -30,8 +30,7 @@ Outputs:
|
||||
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
|
||||
those positions via the PaliGemma ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``. Same
|
||||
semantics as the SmolVLA2 step.
|
||||
target messages has ``message_streams[i] == "low_level"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -54,11 +53,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Debug helper — see ``chat_processor_smolvla2._dump_recipe_sample`` for the
|
||||
# matching SmolVLA2 implementation. Behaviour: when
|
||||
# ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N samples processed (on
|
||||
# rank 0) are pretty-printed with ``[TGT]...[/TGT]`` markers over the spans
|
||||
# the LM head will be supervised on.
|
||||
# Debug helper — when ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N
|
||||
# samples processed (on rank 0) are pretty-printed with ``[TGT]...[/TGT]``
|
||||
# markers over the spans the LM head will be supervised on.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
|
||||
@@ -246,8 +243,8 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
# values exceeding the camera's pixel dimensions — they're not pixels.)
|
||||
# Converting to ``<loc>`` is therefore camera-resolution-independent:
|
||||
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
|
||||
# not in the dataset — so the dataset stays backbone-agnostic (SmolVLA2
|
||||
# keeps the JSON).
|
||||
# not in the dataset — so the dataset keeps the raw JSON and stays
|
||||
# backbone-agnostic.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The 0–1000 scale Qwen2.5-VL emits for grounding coordinates.
|
||||
@@ -424,8 +421,8 @@ def _format_messages(
|
||||
class PI052TextTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
π0.5 analogue of ``SmolVLA2ChatTokenizerStep``. No chat template;
|
||||
concatenates messages as ``User: ... \\nAssistant: ...`` text.
|
||||
No chat template; concatenates messages as
|
||||
``User: ... \\nAssistant: ...`` text.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
@@ -613,8 +610,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
continue
|
||||
labels[token_pos] = input_ids[token_pos]
|
||||
|
||||
# Scan ALL message streams (not just targets) — see
|
||||
# ``chat_processor_smolvla2.py`` for rationale: the v2
|
||||
# Scan ALL message streams (not just targets): the
|
||||
# ``low_level_execution`` recipe drops ``target: true`` on
|
||||
# the assistant to avoid trivial copy-from-user text-CE; the
|
||||
# flow loss still needs to fire, gated by ``stream: low_level``.
|
||||
@@ -686,7 +682,7 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
|
||||
|
||||
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Heuristic content-prefix classifier — mirrors SmolVLA2's."""
|
||||
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 — SmolVLA with the SmolVLM language head re-enabled.
|
||||
|
||||
SmolVLA strips the LM head from the SmolVLM backbone because it only does
|
||||
flow-matching action prediction. SmolVLA2 keeps the LM head so the same
|
||||
model can train on the full Hi Robot / MEM / ECoT message blend defined in
|
||||
the steerable annotation plan (PR1 + PR2):
|
||||
|
||||
* action-only sub-recipes (e.g. ``low_level_execution``) → flow loss
|
||||
* text-only sub-recipes (e.g. ``memory_update``, ``ask_vqa``,
|
||||
``user_interjection_response``, ``high_level_subtask``) → CE loss on
|
||||
``lm_head`` over the recipe's target message tokens
|
||||
* mixed sub-recipes → both losses summed (weighted)
|
||||
|
||||
The ``predict_actions`` toggle follows the Pi0.5 convention from Section
|
||||
I.7 of the plan: ``True`` if any ``low_level`` target is present in the
|
||||
sample, else ``False``.
|
||||
|
||||
This package is a thin subclass of ``lerobot.policies.smolvla`` so most of
|
||||
the model code stays in one place — only the dual-loss path and the
|
||||
chat-template processor live here.
|
||||
"""
|
||||
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
__all__ = ["SmolVLA2Config"]
|
||||
@@ -1,649 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2's chat-template tokenization step.
|
||||
|
||||
Replaces SmolVLA's plain ``TokenizerProcessorStep`` for SmolVLA2 when a
|
||||
``recipe_path`` is set. Reads the rendered messages produced by
|
||||
``RenderMessagesStep`` (PR 1) and produces:
|
||||
|
||||
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` —
|
||||
the chat-templated prompt tokenized by SmolVLM's tokenizer, with
|
||||
``tools=meta.tools`` (PR 1's catalog).
|
||||
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||
the positions belonging to messages whose index is in
|
||||
``target_message_indices``. The next commit's modeling forward path
|
||||
applies cross-entropy on those positions via the SmolVLM ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``. The
|
||||
modeling forward uses this to gate the flow head.
|
||||
|
||||
Image / video content blocks in the rendered messages are dropped
|
||||
before tokenization — the chat template only handles text, and SmolVLA
|
||||
already passes camera tensors out-of-band via the standard
|
||||
``OBS_IMAGES_*`` features. This keeps the prefix layout unchanged
|
||||
(``embed_prefix`` puts image embeddings before language embeddings,
|
||||
matching the chat-template-stripped text order).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Debug helper: dump the first N rendered samples to stdout so you can sanity-
|
||||
# check what the model actually sees before kicking off a long training run.
|
||||
#
|
||||
# LEROBOT_DUMP_RECIPE_SAMPLES=5 lerobot-train ...
|
||||
#
|
||||
# Prints the recipe-rendered messages, the chat-templated text (decoded back
|
||||
# from token ids), and inline ``[TGT]...[/TGT]`` markers showing which spans
|
||||
# are supervised by text-CE. Stops after N total dumps to keep training logs
|
||||
# tractable. Rank-0 only when accelerate sets ``RANK``.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
|
||||
_DUMPED_SO_FAR = 0
|
||||
|
||||
|
||||
def _is_dump_rank() -> bool:
|
||||
rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
|
||||
try:
|
||||
return int(rank) == 0
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
|
||||
def _dump_recipe_sample(
|
||||
*,
|
||||
messages: list[dict[str, Any]],
|
||||
token_ids: list[int],
|
||||
labels: list[int],
|
||||
predict_actions: bool,
|
||||
tokenizer: Any,
|
||||
) -> None:
|
||||
"""Pretty-print one rendered sample. Stops once the global budget is hit."""
|
||||
global _DUMPED_SO_FAR
|
||||
if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
|
||||
return
|
||||
_DUMPED_SO_FAR += 1
|
||||
|
||||
decoded = tokenizer.decode(token_ids, skip_special_tokens=False)
|
||||
parts: list[str] = []
|
||||
i = 0
|
||||
while i < len(labels):
|
||||
if labels[i] == -100:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] == -100:
|
||||
j += 1
|
||||
parts.append(tokenizer.decode(token_ids[i:j], skip_special_tokens=False))
|
||||
i = j
|
||||
else:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] != -100:
|
||||
j += 1
|
||||
tgt_text = tokenizer.decode(token_ids[i:j], skip_special_tokens=False)
|
||||
parts.append(f"[TGT]{tgt_text}[/TGT]")
|
||||
i = j
|
||||
annotated = "".join(parts)
|
||||
|
||||
n_tgt = sum(1 for l in labels if l != -100)
|
||||
print(
|
||||
"\n========== RECIPE SAMPLE DUMP "
|
||||
f"({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
|
||||
flush=True,
|
||||
)
|
||||
print(f" predict_actions: {predict_actions}", flush=True)
|
||||
print(f" rendered messages ({len(messages)}):", flush=True)
|
||||
for m in messages:
|
||||
stream = m.get("stream")
|
||||
target = m.get("target")
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
print(f" - role={role} stream={stream} target={target}", flush=True)
|
||||
print(f" content: {content!r}", flush=True)
|
||||
print(f" token count: {len(token_ids)} (target tokens: {n_tgt})", flush=True)
|
||||
print(f" decoded (raw):\n {decoded}", flush=True)
|
||||
print(f" decoded (with target markers):\n {annotated}", flush=True)
|
||||
print("==============================================\n", flush=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
|
||||
class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
This is the bridge between the recipe stack (PR 1's
|
||||
``RenderMessagesStep`` outputs) and the SmolVLA2 modeling forward
|
||||
(next commit, which reads ``text_labels`` / ``predict_actions``).
|
||||
Pure-text turns and multi-stream targets are both handled.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
|
||||
max_length: int = 2048
|
||||
padding: str = "longest"
|
||||
padding_side: str = "right"
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
# --- Per-component prompt dropout (Pi0.7 §V.E, plan follow-up
|
||||
# ``feat/pi05-prompt-dropout``). At training, drop non-target
|
||||
# messages whose content was substituted from the named recipe
|
||||
# binding with the given probability. Forces the model to handle
|
||||
# missing context at inference — directly attacks the memorisation
|
||||
# collapse where ``current_subtask=""`` puts the prompt OOD. All
|
||||
# default to 0.0 (no dropout) so behaviour is identical until
|
||||
# explicitly opted in via the training config.
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
interjection_dropout_prob: float = 0.0
|
||||
# Optional seed for the per-sample RNG. ``None`` ⇒ use
|
||||
# ``sample_idx`` derived from the transition (when present), so
|
||||
# dropout is reproducible across runs but varies per sample.
|
||||
dropout_seed: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Lazy: don't load the tokenizer until the step actually runs,
|
||||
# so unit tests that import the module without transformers
|
||||
# installed still pass.
|
||||
self._tokenizer: Any = None
|
||||
if self.tools is None:
|
||||
# Default: no tools rendered into the system prompt. The
|
||||
# ``say()`` tool was only used by the now-removed
|
||||
# ``user_interjection_response`` recipe; including its
|
||||
# schema on every sample adds a long system message to
|
||||
# the action expert's prefix and creates a train/inference
|
||||
# mismatch (the inference low-level loop doesn't pass
|
||||
# tools=, so the chat template doesn't render them).
|
||||
# Users who actually need tools can set them via
|
||||
# ``with_tools(meta.tools)``.
|
||||
self.tools = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def with_tools(self, tools: list[dict[str, Any]]) -> "SmolVLA2ChatTokenizerStep":
|
||||
"""Override the tools catalog rendered into the system prompt."""
|
||||
self.tools = list(tools)
|
||||
return self
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
messages = comp.get("messages")
|
||||
if not messages:
|
||||
# No recipe rendering happened — nothing to do; downstream
|
||||
# falls back to whatever ``task`` is in the transition.
|
||||
return transition
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
# Pull a sample_idx for the dropout RNG. ``index`` is the
|
||||
# canonical per-frame key on ``LeRobotDataset`` samples and
|
||||
# flows through into ``COMPLEMENTARY_DATA`` unchanged. When
|
||||
# absent (e.g. inference) we fall back to 0 which is harmless
|
||||
# because the dropout probs are also 0 at inference time.
|
||||
if _is_batched_messages(messages):
|
||||
indices_iter = _sample_indices(comp.get("index"), len(messages))
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
msg,
|
||||
list(streams),
|
||||
sorted(int(i) for i in tgt_indices),
|
||||
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||
)
|
||||
for msg, streams, tgt_indices, s_idx in zip(
|
||||
messages,
|
||||
comp.get("message_streams") or [[] for _ in messages],
|
||||
comp.get("target_message_indices") or [[] for _ in messages],
|
||||
indices_iter,
|
||||
strict=False,
|
||||
)
|
||||
]
|
||||
else:
|
||||
sample_idx = _sample_indices(comp.get("index"), 1)[0]
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
messages,
|
||||
list(comp.get("message_streams") or []),
|
||||
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
]
|
||||
|
||||
# Optional first-N-samples debug dump for sanity-checking what the
|
||||
# model actually sees. No-op unless ``LEROBOT_DUMP_RECIPE_SAMPLES``
|
||||
# is set; stops globally after the budget is exhausted.
|
||||
if _DUMP_BUDGET > 0:
|
||||
# Stream / target metadata live in parallel arrays in
|
||||
# COMPLEMENTARY_DATA, not on the message dicts themselves
|
||||
# (the recipe renderer keeps them separate so the chat
|
||||
# template doesn't choke on unknown keys). Zip them back
|
||||
# together for the dumper so each printed message shows
|
||||
# its actual stream + target flag.
|
||||
if _is_batched_messages(messages):
|
||||
msgs_iter = messages
|
||||
streams_iter = comp.get("message_streams") or [[] for _ in messages]
|
||||
targets_iter = comp.get("target_message_indices") or [[] for _ in messages]
|
||||
else:
|
||||
msgs_iter = [messages]
|
||||
streams_iter = [list(comp.get("message_streams") or [])]
|
||||
targets_iter = [list(comp.get("target_message_indices") or [])]
|
||||
for msg, streams, targets, (ids, labels, predict_action) in zip(
|
||||
msgs_iter, streams_iter, targets_iter, encoded, strict=False
|
||||
):
|
||||
target_set = {int(i) for i in targets}
|
||||
annotated_msgs = []
|
||||
for i, m in enumerate(msg):
|
||||
annotated_msgs.append(
|
||||
{
|
||||
**m,
|
||||
"stream": streams[i] if i < len(streams) else None,
|
||||
"target": True if i in target_set else None,
|
||||
}
|
||||
)
|
||||
_dump_recipe_sample(
|
||||
messages=annotated_msgs,
|
||||
token_ids=ids,
|
||||
labels=labels,
|
||||
predict_actions=predict_action,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
target_length = self.max_length if self.padding == "max_length" else max(
|
||||
len(ids) for ids, _, _ in encoded
|
||||
)
|
||||
target_length = min(target_length, self.max_length)
|
||||
|
||||
ids_batch = []
|
||||
attn_batch = []
|
||||
labels_batch = []
|
||||
predict_actions = []
|
||||
for ids, labels, predict_action in encoded:
|
||||
ids = ids[:target_length]
|
||||
labels = labels[:target_length]
|
||||
attn = [1] * len(ids)
|
||||
if len(ids) < target_length:
|
||||
n_pad = target_length - len(ids)
|
||||
ids = ids + [pad_id] * n_pad
|
||||
labels = labels + [-100] * n_pad
|
||||
attn = attn + [0] * n_pad
|
||||
ids_batch.append(ids)
|
||||
attn_batch.append(attn)
|
||||
labels_batch.append(labels)
|
||||
predict_actions.append(predict_action)
|
||||
|
||||
ids_t = torch.tensor(ids_batch, dtype=torch.long)
|
||||
attn_t = torch.tensor(attn_batch, dtype=torch.bool)
|
||||
labels_t = torch.tensor(labels_batch, dtype=torch.long)
|
||||
predict_actions_t = torch.tensor(predict_actions, dtype=torch.bool)
|
||||
|
||||
if not _is_batched_messages(messages):
|
||||
ids_t = ids_t.squeeze(0)
|
||||
attn_t = attn_t.squeeze(0)
|
||||
labels_t = labels_t.squeeze(0)
|
||||
predict_actions_t = predict_actions_t.squeeze(0)
|
||||
|
||||
new_complementary = dict(comp)
|
||||
# Drop the per-recipe sidecar keys; everything downstream needs
|
||||
# is now in the tokenized form.
|
||||
new_complementary.pop("messages", None)
|
||||
new_complementary.pop("message_streams", None)
|
||||
new_complementary.pop("target_message_indices", None)
|
||||
# SmolVLA's pipeline expects ``OBS_LANGUAGE_TOKENS`` /
|
||||
# ``OBS_LANGUAGE_ATTENTION_MASK`` on the OBSERVATION key. Place
|
||||
# them there — and drop ``task`` so the upstream
|
||||
# ``TokenizerProcessorStep`` (which we replace) doesn't double-
|
||||
# tokenize.
|
||||
observation = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
observation[OBS_LANGUAGE_TOKENS] = ids_t
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
|
||||
new_complementary["text_labels"] = labels_t
|
||||
new_complementary["predict_actions"] = predict_actions_t
|
||||
new_complementary.pop("task", None)
|
||||
|
||||
new_transition = dict(transition)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
def _encode_messages(
|
||||
self,
|
||||
tokenizer: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
message_streams: list[str | None],
|
||||
target_indices: list[int],
|
||||
sample_idx: int | None = None,
|
||||
) -> tuple[list[int], list[int], bool]:
|
||||
# Apply per-component prompt dropout *before* tokenisation, so
|
||||
# the dropped messages don't contribute tokens or label-mask
|
||||
# positions at all. Re-maps ``target_indices`` to account for
|
||||
# removed messages.
|
||||
messages, target_indices = self._apply_prompt_dropout(
|
||||
messages, target_indices, sample_idx
|
||||
)
|
||||
# Flatten ``tool_calls`` into a textual ``<say>...</say>`` marker
|
||||
# *before* the chat template sees them, so the model is trained
|
||||
# to emit the same marker the inference parser
|
||||
# (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``.
|
||||
messages = [_flatten_say_tool_calls(m) for m in messages]
|
||||
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
||||
|
||||
full_ids = tokenizer.apply_chat_template(
|
||||
text_messages,
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
full_ids = _as_token_ids(full_ids)
|
||||
|
||||
labels = [-100] * len(full_ids)
|
||||
for tgt in target_indices:
|
||||
prefix_ids = tokenizer.apply_chat_template(
|
||||
text_messages[:tgt],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
full_through_target = tokenizer.apply_chat_template(
|
||||
text_messages[: tgt + 1],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
prefix_ids = _as_token_ids(prefix_ids)
|
||||
full_through_target = _as_token_ids(full_through_target)
|
||||
start = len(prefix_ids)
|
||||
end = min(len(full_through_target), len(full_ids))
|
||||
for pos in range(start, end):
|
||||
labels[pos] = int(full_ids[pos])
|
||||
|
||||
# ``predict_actions`` is True iff this sample's recipe declares
|
||||
# at least one ``low_level`` message — regardless of whether
|
||||
# it's a target. The ``low_level_execution`` recipe in v2 uses
|
||||
# ``stream: low_level`` on both user and assistant turns but
|
||||
# only renders the *user* subtask (no text-CE target on the
|
||||
# assistant) to avoid trivial "copy previous turn" supervision.
|
||||
# Scanning targets alone would miss this sample's action loss.
|
||||
predict_actions = any(s == "low_level" for s in message_streams)
|
||||
return [int(i) for i in full_ids], labels, predict_actions
|
||||
|
||||
def _apply_prompt_dropout(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
sample_idx: int | None,
|
||||
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||
"""Probabilistically drop non-target context messages.
|
||||
|
||||
Heuristic content sniffing — matches the prefix strings that
|
||||
``subtask_mem_vqa_speech.yaml``'s recipes use when injecting plan /
|
||||
memory / subtask / interjection content. Anything else is
|
||||
kept unchanged. Target messages are never dropped (we still
|
||||
need their tokens for supervision).
|
||||
|
||||
Returns ``(new_messages, new_target_indices)`` where the
|
||||
indices are re-mapped to point at the same target turns in
|
||||
the trimmed list.
|
||||
"""
|
||||
probs = {
|
||||
"plan": float(self.plan_dropout_prob or 0.0),
|
||||
"memory": float(self.memory_dropout_prob or 0.0),
|
||||
"subtask": float(self.subtask_dropout_prob or 0.0),
|
||||
"interjection": float(self.interjection_dropout_prob or 0.0),
|
||||
}
|
||||
if not any(p > 0.0 for p in probs.values()):
|
||||
return messages, target_indices
|
||||
|
||||
# Deterministic per-sample RNG so dropout is reproducible
|
||||
# across runs (matters for debugging / repro) but varies
|
||||
# frame-to-frame.
|
||||
import random # noqa: PLC0415
|
||||
|
||||
seed_int = self.dropout_seed if self.dropout_seed is not None else (sample_idx or 0)
|
||||
rng = random.Random(int(seed_int) & 0xFFFFFFFF)
|
||||
|
||||
target_set = set(target_indices)
|
||||
keep_flags: list[bool] = []
|
||||
for i, msg in enumerate(messages):
|
||||
if i in target_set:
|
||||
keep_flags.append(True)
|
||||
continue
|
||||
kind = _classify_message_for_dropout(msg)
|
||||
if kind and rng.random() < probs.get(kind, 0.0):
|
||||
keep_flags.append(False)
|
||||
else:
|
||||
keep_flags.append(True)
|
||||
|
||||
new_messages = [m for m, keep in zip(messages, keep_flags) if keep]
|
||||
# Re-map target_indices: each old index drops by the count of
|
||||
# falsy flags before it.
|
||||
new_target_indices: list[int] = []
|
||||
for old_idx in target_indices:
|
||||
dropped_before = sum(1 for k in keep_flags[:old_idx] if not k)
|
||||
new_target_indices.append(old_idx - dropped_before)
|
||||
return new_messages, sorted(new_target_indices)
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass-through; this step writes runtime tensors not features."""
|
||||
return features
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_tokenizer(self): # noqa: ANN202
|
||||
if self._tokenizer is not None:
|
||||
return self._tokenizer
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError(
|
||||
"SmolVLA2ChatTokenizerStep requires transformers. "
|
||||
"`pip install lerobot[transformers-dep]`."
|
||||
) from exc
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None:
|
||||
self._tokenizer.pad_token = self._tokenizer.eos_token
|
||||
return self._tokenizer
|
||||
|
||||
|
||||
def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove LeRobot-specific multimodal blocks from ``message`` content.
|
||||
|
||||
The recipe DSL allows authors to write multimodal content like
|
||||
``{"type": "image", "feature": "observation.images.top"}``. SmolVLM's
|
||||
tokenizer doesn't know that ``feature`` key (it expects ``url`` or
|
||||
``path``). The actual image tensor flows through SmolVLA's
|
||||
``OBS_IMAGES_*`` channels separately; the chat template only needs
|
||||
the text. So we strip non-text blocks before tokenizing.
|
||||
"""
|
||||
new = dict(message)
|
||||
content = new.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text_parts.append({"type": "text", "text": str(block.get("text", ""))})
|
||||
new["content"] = text_parts or [{"type": "text", "text": ""}]
|
||||
elif content is None:
|
||||
new["content"] = [{"type": "text", "text": ""}]
|
||||
else:
|
||||
new["content"] = [{"type": "text", "text": str(content)}]
|
||||
if "tool_calls" in new and not new["tool_calls"]:
|
||||
# Drop empty tool_calls — some templates render them as a
|
||||
# spurious empty marker.
|
||||
new.pop("tool_calls")
|
||||
# ``stream`` and ``target`` were recipe metadata; templates don't
|
||||
# know them and may warn or crash.
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
"""Collapse a message's ``content`` (string or multimodal blocks) to plain text."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
t = block.get("text")
|
||||
if isinstance(t, str):
|
||||
parts.append(t)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize assistant ``say`` tool calls into a textual ``<say>...</say>``
|
||||
marker inside the message content (Pi 0.5-style flat tool-call
|
||||
serialization).
|
||||
|
||||
SmolVLM's chat template would otherwise render ``tool_calls`` as a
|
||||
structured JSON ``<tool_call>`` block, so the LM head learns to emit
|
||||
JSON — but the inference parser ``_split_plan_and_say`` looks for a
|
||||
``<say>...</say>`` marker (``_SAY_RE``). Rewriting the call into the
|
||||
content text *before* ``apply_chat_template`` aligns the two: the
|
||||
template only ever tokenizes plain text, and the supervised target
|
||||
span trains the model to produce the exact marker the runtime reads.
|
||||
|
||||
Messages without ``say`` tool calls are returned unchanged.
|
||||
"""
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
return message
|
||||
|
||||
say_texts: list[str] = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
fn = call.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
continue
|
||||
args = fn.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
import json # noqa: PLC0415
|
||||
|
||||
args = json.loads(args)
|
||||
except (ValueError, TypeError):
|
||||
args = {}
|
||||
text = args.get("text", "") if isinstance(args, dict) else ""
|
||||
if text:
|
||||
say_texts.append(str(text))
|
||||
|
||||
if not say_texts:
|
||||
# No ``say`` calls (or empty text) — drop the structured calls so
|
||||
# the template doesn't render a stray JSON block, but leave the
|
||||
# content alone.
|
||||
new = dict(message)
|
||||
new.pop("tool_calls", None)
|
||||
return new
|
||||
|
||||
new = dict(message)
|
||||
base = _content_to_text(new.get("content")).strip()
|
||||
marker = "".join(f"<say>{t}</say>" for t in say_texts)
|
||||
new["content"] = f"{base}\n{marker}" if base else marker
|
||||
new.pop("tool_calls", None)
|
||||
return new
|
||||
|
||||
|
||||
def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
|
||||
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
if value is None:
|
||||
return [None] * batch_size
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.numel() == 1:
|
||||
return [int(value.item())] * batch_size
|
||||
values = value.reshape(-1).tolist()
|
||||
return [int(v) for v in values[:batch_size]]
|
||||
if isinstance(value, (list, tuple)):
|
||||
if len(value) == 1:
|
||||
return _sample_indices(value[0], batch_size)
|
||||
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
|
||||
return [int(value)] * batch_size
|
||||
|
||||
|
||||
def _classify_message_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Best-effort classification of which recipe binding contributed
|
||||
to this message, used for per-component dropout.
|
||||
|
||||
The canonical recipe authors plan/memory/subtask injections with
|
||||
distinctive prefix strings in the rendered content. Matching on
|
||||
those prefixes is brittle if a future recipe author uses
|
||||
different wording — but it's also localised to one place and
|
||||
only affects the dropout fraction (never the actual semantics).
|
||||
Returns ``None`` for messages we don't recognise; those are
|
||||
always kept.
|
||||
"""
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
t = block.get("text")
|
||||
if isinstance(t, str):
|
||||
text_parts.append(t)
|
||||
content = "\n".join(text_parts)
|
||||
if not isinstance(content, str):
|
||||
return None
|
||||
head = content.lstrip().lower()
|
||||
if head.startswith("plan:") or head.startswith("previous plan"):
|
||||
return "plan"
|
||||
if head.startswith("memory:") or head.startswith("previous memory"):
|
||||
return "memory"
|
||||
if head.startswith("current subtask") or head.startswith("completed subtask"):
|
||||
return "subtask"
|
||||
return None
|
||||
|
||||
|
||||
def _as_token_ids(value: Any) -> list[int]:
|
||||
if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()):
|
||||
value = value["input_ids"]
|
||||
if hasattr(value, "tolist"):
|
||||
value = value.tolist()
|
||||
if isinstance(value, list) and value and isinstance(value[0], list):
|
||||
value = value[0]
|
||||
return [int(i) for i in value]
|
||||
|
||||
|
||||
# Re-export for tests / introspection
|
||||
strip_lerobot_blocks = _strip_lerobot_blocks
|
||||
flatten_say_tool_calls = _flatten_say_tool_calls
|
||||
@@ -1,163 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
|
||||
from ..smolvla.configuration_smolvla import SmolVLAConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla2")
|
||||
@dataclass
|
||||
class SmolVLA2Config(SmolVLAConfig):
|
||||
"""SmolVLA2 — SmolVLA with the underlying SmolVLM language head re-enabled.
|
||||
|
||||
SmolVLA strips the LM head from the SmolVLM backbone because it only
|
||||
needs flow-matching action prediction. SmolVLA2 keeps the LM head so the
|
||||
same model can train on:
|
||||
|
||||
* **action-only sub-recipes** (e.g. ``low_level_execution``) — flow loss
|
||||
on the action expert, same as SmolVLA. ``predict_actions=True``.
|
||||
* **text-only sub-recipes** (e.g. ``memory_update`` / ``ask_vqa`` /
|
||||
``user_interjection_response`` / ``high_level_subtask``) — cross-
|
||||
entropy loss on the LM head over the recipe's target message tokens.
|
||||
Skips the flow head entirely. ``predict_actions=False``.
|
||||
* **mixed sub-recipes** — both heads run, losses summed (weighted).
|
||||
|
||||
The split is controlled by ``predict_actions = bool(targets_by_stream
|
||||
.get("low_level"))`` per the Pi0.5 convention in the steerable
|
||||
annotation plan (Section I.7), implemented inside the processor /
|
||||
forward path. Recipes drive it via ``stream`` + ``target`` metadata.
|
||||
|
||||
Compared to ``SmolVLAConfig`` this adds:
|
||||
|
||||
- ``recipe_path``: path to a ``TrainingRecipe`` YAML (loaded by the
|
||||
train script). When ``None``, SmolVLA2 falls back to the SmolVLA
|
||||
task-only path so unannotated datasets still work.
|
||||
- ``text_loss_weight`` / ``flow_loss_weight``: relative weights when
|
||||
both losses are active in a single sample.
|
||||
- ``unfreeze_lm_head``: must be ``True`` for the text head to learn —
|
||||
SmolVLA freezes ``lm_head`` to "avoid unused params issues" and we
|
||||
need to undo that for SmolVLA2.
|
||||
- ``train_expert_only=False`` by default, since the VLM body now also
|
||||
participates in text-target gradients.
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. The default points at the canonical Hi Robot
|
||||
blend shipped alongside SmolVLA2. Set to ``None`` to disable recipe
|
||||
rendering and fall back to SmolVLA's single-task prompt path
|
||||
(unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = True
|
||||
"""Apply the SmolVLM tokenizer's chat template to the rendered messages
|
||||
before tokenizing. SmolVLM's backbone is chat-pretrained, so this
|
||||
matches its training distribution."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
# Pi 0.5 paper §IV.D (Eq. 1) sets α = 10 between the text-CE term
|
||||
# and the flow-MSE term: L = H(text) + α * ‖ω - a - f_θ‖². The
|
||||
# rationale is that actions are the primary output and the flow
|
||||
# head should dominate the gradient signal; text is supervised as
|
||||
# an auxiliary task and its CE scale (~0.5-2.0 in nats) tends to
|
||||
# be larger than the flow MSE scale (~0.1-1.0), so without
|
||||
# up-weighting the action head gets starved. We use a milder
|
||||
# split (5:1) than the paper's α=10: ~40% of the blend is the
|
||||
# flow-only ``low_level`` recipe, so the flow term already fires
|
||||
# often, and α=10 starved the text head into degenerate decoding.
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / SmolVLA behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 5.0
|
||||
"""Weight on the action-expert flow-matching term. Default 5.0 — a
|
||||
milder split than the Pi 0.5 paper's α=10 (§IV.D), since the
|
||||
flow-only ``low_level`` recipe already gives the action expert
|
||||
frequent gradient. Set lower if the text head is underfitting
|
||||
relative to the action expert; set higher if the action expert is
|
||||
degrading because text loss dominates."""
|
||||
|
||||
# Optimizer -----------------------------------------------------------
|
||||
optimizer_lr: float = 2.5e-5
|
||||
"""Peak learning rate. Overrides ``SmolVLAConfig``'s ``1e-4``.
|
||||
|
||||
SmolVLA can afford ``1e-4`` because it *freezes* the language head —
|
||||
only the from-scratch action expert sees that LR. SmolVLA2 unfreezes
|
||||
``lm_head`` + the last text layer and fine-tunes the **pretrained**
|
||||
SmolVLM2 language weights, and ``1e-4`` is too aggressive for a
|
||||
pretrained LM: it destabilises the language representations and
|
||||
collapses generation into degenerate repetition. ``2.5e-5`` matches
|
||||
pi05's peak LR (openpi ``CosineDecaySchedule``), the comparable
|
||||
text-co-trained policy. The action expert trains slightly slower at
|
||||
this LR, so budget more steps."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to unfreeze the SmolVLM ``lm_head`` (and the immediately
|
||||
preceding norm + last text-model layer that SmolVLA freezes). Must be
|
||||
``True`` for the text head to learn. Setting this to ``False``
|
||||
effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
|
||||
which is occasionally useful for ablations."""
|
||||
|
||||
load_vlm_weights: bool = True
|
||||
"""Load the pretrained SmolVLM2 backbone weights (vision tower +
|
||||
language model + ``lm_head``) instead of random-initialising them.
|
||||
|
||||
``SmolVLAConfig`` defaults this to ``False`` because the original
|
||||
SmolVLA pre-training run trained the VLM body itself. For SmolVLA2
|
||||
that default is a footgun: the text head **is** the SmolVLM2
|
||||
``lm_head``, and the high-level subtask supervision is hopeless if
|
||||
it starts from a random language model — it can only memorise.
|
||||
SmolVLA2 therefore defaults this to ``True`` so every run fine-tunes
|
||||
from the pretrained ``vlm_model_name`` checkpoint
|
||||
(``HuggingFaceTB/SmolVLM2-500M-Video-Instruct``).
|
||||
|
||||
Note this loads the *VLM backbone* pretrained; the action expert
|
||||
still trains from scratch on the robot data (standard SmolVLA
|
||||
fine-tuning). To also start the action expert from pretrained
|
||||
weights, fine-tune from a full ``lerobot/smolvla_base`` checkpoint
|
||||
via ``--policy.path``."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# At training, randomly drop non-target context messages whose
|
||||
# content was substituted from the named recipe binding. Forces
|
||||
# the model to handle missing context — directly attacks the
|
||||
# memorisation collapse where a stale or missing plan/memory at
|
||||
# inference puts the prompt out-of-distribution and the LM head
|
||||
# falls back to dominant-mode fragments. All default to 0.0 so
|
||||
# behaviour is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
"""Drop messages whose content starts with ``Plan:`` or ``Previous plan``
|
||||
with this probability per sample."""
|
||||
memory_dropout_prob: float = 0.0
|
||||
"""Drop messages whose content starts with ``Memory:`` or ``Previous memory``
|
||||
with this probability per sample."""
|
||||
subtask_dropout_prob: float = 0.0
|
||||
"""Drop messages whose content starts with ``Current subtask`` or
|
||||
``Completed subtask`` with this probability per sample."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through its text path when the
|
||||
# LM head is producing supervised text. Override the SmolVLA
|
||||
# default (`train_expert_only=True`) unless the user explicitly
|
||||
# opts out of text training via `text_loss_weight=0`.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
# The user can still flip this back via CLI; this only
|
||||
# changes the *default* when SmolVLA2 is actually training a
|
||||
# text head.
|
||||
self.train_expert_only = False
|
||||
@@ -1,692 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
|
||||
|
||||
Adds:
|
||||
|
||||
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
|
||||
* a forward path that runs the flow head, the text head, or both,
|
||||
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
|
||||
produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on
|
||||
this branch).
|
||||
|
||||
Per-sample routing — within one batch:
|
||||
|
||||
* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
|
||||
loss (action chunk supervision).
|
||||
* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
|
||||
flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
|
||||
contribute to the LM-head cross-entropy.
|
||||
|
||||
Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
|
||||
``text_labels`` nor ``predict_actions`` is in the batch — unannotated
|
||||
datasets keep working unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
|
||||
from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, int]:
|
||||
"""Find ``[lang_start, lang_end)`` inside the SmolVLA prefix.
|
||||
|
||||
``embed_prefix`` lays out the prefix as
|
||||
``[image_blocks..., lang, state, padding]`` with the att-mask
|
||||
convention ``[0]*image, [0]*lang, [1]*state, [0]*padding`` (see
|
||||
``modeling_smolvla.SmolVLAModel.embed_prefix``). State is exactly
|
||||
one token, and it's the *only* position with ``att_mask == 1``,
|
||||
so we use the first ``1`` to anchor lang_end. Computing it this
|
||||
way is robust to (a) state being projected to one embedding token
|
||||
regardless of its raw feature dim, and (b) the trailing padding
|
||||
added when ``seq_len < prefix_length``.
|
||||
"""
|
||||
row = prefix_att_masks[0]
|
||||
ones = row.nonzero(as_tuple=False)
|
||||
if ones.numel() == 0:
|
||||
raise RuntimeError(
|
||||
"SmolVLA2: state token not found in prefix att_masks — "
|
||||
"can't locate language range."
|
||||
)
|
||||
state_start = int(ones[0, 0].item())
|
||||
lang_end = state_start
|
||||
lang_start = lang_end - num_lang
|
||||
if lang_start < 0:
|
||||
raise RuntimeError(
|
||||
f"SmolVLA2: lang range underflows prefix "
|
||||
f"(state_start={state_start}, num_lang={num_lang})."
|
||||
)
|
||||
return lang_start, lang_end
|
||||
|
||||
|
||||
def _mark_target_span_causal(
|
||||
prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int
|
||||
) -> Tensor:
|
||||
"""Make the supervised text-target span causally masked.
|
||||
|
||||
``embed_prefix`` flags every language token with ``att=0``, which
|
||||
``make_att_2d_masks`` turns into one fully *bidirectional* block —
|
||||
so a target token's hidden state attends to the very tokens it is
|
||||
supposed to predict. The text cross-entropy then degenerates into
|
||||
a copy task (loss → ~0) and the model never learns causal
|
||||
next-token prediction — at inference, where ``select_message``
|
||||
decodes autoregressively (causally), it collapses.
|
||||
|
||||
Fix: set ``att=1`` on the language positions that are supervised
|
||||
targets (``text_labels != -100``). With ``make_att_2d_masks``'s
|
||||
cumulative-block rule each target token then attends to images +
|
||||
the user prompt bidirectionally and to *earlier* target tokens
|
||||
only — i.e. genuine causal next-token prediction, matching
|
||||
inference. Non-target language (the user prompt, and the
|
||||
``low_level_execution`` subtask which is a user turn, not a
|
||||
target) stays ``att=0`` / bidirectional. The action expert is
|
||||
unaffected: the suffix has a strictly higher cumsum so it still
|
||||
attends to every prefix token.
|
||||
"""
|
||||
att = prefix_att_masks.clone()
|
||||
n = min(text_labels.shape[1], lang_end - lang_start)
|
||||
if n <= 0:
|
||||
return att
|
||||
target = text_labels[:, :n] != -100 # (B, n) bool
|
||||
seg = att[:, lang_start : lang_start + n].bool()
|
||||
att[:, lang_start : lang_start + n] = seg | target
|
||||
return att
|
||||
|
||||
|
||||
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
|
||||
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
|
||||
num_lang = logits.shape[1]
|
||||
if text_labels.shape[1] != num_lang:
|
||||
common = min(text_labels.shape[1], num_lang)
|
||||
logits = logits[:, :common]
|
||||
text_labels = text_labels[:, :common]
|
||||
shift_logits = logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
valid = shift_labels != -100
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
reduction="sum",
|
||||
)
|
||||
return loss / valid.sum().clamp(min=1)
|
||||
|
||||
|
||||
class SmolVLA2Policy(SmolVLAPolicy):
|
||||
"""SmolVLA + re-enabled SmolVLM language head."""
|
||||
|
||||
config_class = SmolVLA2Config
|
||||
name = "smolvla2"
|
||||
|
||||
def __init__(self, config: SmolVLA2Config, **kwargs):
|
||||
if not isinstance(config, SmolVLA2Config):
|
||||
config = SmolVLA2Config(
|
||||
**{
|
||||
f.name: getattr(config, f.name)
|
||||
for f in config.__dataclass_fields__.values()
|
||||
if hasattr(config, f.name)
|
||||
}
|
||||
)
|
||||
super().__init__(config, **kwargs)
|
||||
if config.unfreeze_lm_head and config.text_loss_weight > 0:
|
||||
self._unfreeze_lm_head()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backbone surgery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _unfreeze_lm_head(self) -> None:
|
||||
"""Re-enable gradients on the text-output path so the LM head
|
||||
loss can flow back.
|
||||
|
||||
SmolVLA's ``set_requires_grad`` freezes three things when
|
||||
``train_expert_only=False``: ``lm_head``,
|
||||
``text_model.model.norm.weight``, and the last 1-2 text-model
|
||||
transformer layers (see ``smolvlm_with_expert.py:167-176``).
|
||||
We must unfreeze *all three* — otherwise gradients still die
|
||||
in the frozen final block and the lm_head learns nothing.
|
||||
"""
|
||||
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
|
||||
if vlm_with_expert is None:
|
||||
return
|
||||
vlm = getattr(vlm_with_expert, "vlm", None)
|
||||
if vlm is None:
|
||||
return
|
||||
|
||||
# Mirror the freeze targets from ``smolvlm_with_expert.set_requires_grad``.
|
||||
num_vlm = getattr(vlm_with_expert, "num_vlm_layers", None)
|
||||
num_expert = getattr(vlm_with_expert, "num_expert_layers", None)
|
||||
last_layers = []
|
||||
if num_vlm is not None:
|
||||
last_layers.append(num_vlm - 1)
|
||||
if (
|
||||
num_expert is not None
|
||||
and num_vlm != num_expert
|
||||
and num_vlm % num_expert == 0
|
||||
):
|
||||
last_layers.append(num_vlm - 2)
|
||||
unfreeze_prefixes = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
*[f"text_model.model.layers.{layer}." for layer in last_layers],
|
||||
]
|
||||
for name, param in vlm.named_parameters():
|
||||
if any(k in name for k in unfreeze_prefixes):
|
||||
param.requires_grad = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
noise: Tensor | None = None,
|
||||
time: Tensor | None = None,
|
||||
reduction: str = "mean",
|
||||
) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Forward pass with optional dual-head loss.
|
||||
|
||||
Two routing knobs from the batch (produced by
|
||||
:class:`SmolVLA2ChatTokenizerStep`):
|
||||
|
||||
* ``text_labels`` — per-token labels with ``-100`` for non-target
|
||||
positions. Triggers the text-loss path through ``lm_head``.
|
||||
* ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
|
||||
include this sample's action chunk in the flow loss.
|
||||
|
||||
When neither is present, delegate to ``SmolVLAPolicy.forward``.
|
||||
"""
|
||||
text_labels = batch.get("text_labels")
|
||||
predict_actions_t = batch.get("predict_actions")
|
||||
|
||||
has_text_data = (
|
||||
text_labels is not None
|
||||
and isinstance(text_labels, Tensor)
|
||||
and self.config.text_loss_weight > 0
|
||||
)
|
||||
has_per_sample_routing = (
|
||||
predict_actions_t is not None and isinstance(predict_actions_t, Tensor)
|
||||
)
|
||||
|
||||
if not has_text_data and not has_per_sample_routing:
|
||||
return super().forward(batch, noise=noise, time=time, reduction=reduction)
|
||||
|
||||
loss_dict: dict[str, Any] = {}
|
||||
device = batch[OBS_STATE].device
|
||||
total = torch.zeros((), device=device, dtype=torch.float32)
|
||||
|
||||
run_flow = (
|
||||
self.config.flow_loss_weight > 0
|
||||
and ACTION in batch
|
||||
and (not has_per_sample_routing or bool(predict_actions_t.any().item()))
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Fused path — one backbone forward for flow + text together.
|
||||
# ------------------------------------------------------------
|
||||
if run_flow and has_text_data:
|
||||
flow_loss, text_loss, flow_diag = self._compute_fused_loss(
|
||||
batch, text_labels, predict_actions_t, noise=noise, time=time
|
||||
)
|
||||
total = total + self.config.flow_loss_weight * flow_loss
|
||||
total = total + self.config.text_loss_weight * text_loss
|
||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
for k, v in flow_diag.items():
|
||||
loss_dict[f"flow_{k}"] = v
|
||||
elif run_flow:
|
||||
per_sample_flow, flow_diag = super().forward(
|
||||
batch, noise=noise, time=time, reduction="none"
|
||||
)
|
||||
if has_per_sample_routing:
|
||||
mask = predict_actions_t.to(per_sample_flow.dtype)
|
||||
masked = per_sample_flow * mask
|
||||
denom = mask.sum().clamp(min=1.0)
|
||||
flow_loss = masked.sum() / denom
|
||||
else:
|
||||
flow_loss = per_sample_flow.mean()
|
||||
total = total + self.config.flow_loss_weight * flow_loss
|
||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||
for k, v in flow_diag.items():
|
||||
loss_dict[f"flow_{k}"] = v
|
||||
elif has_text_data:
|
||||
text_loss = self._compute_text_loss(batch, text_labels)
|
||||
total = total + self.config.text_loss_weight * text_loss
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
else:
|
||||
# No path fired — happens when both loss weights are 0 or
|
||||
# the batch has neither action samples nor supervised text.
|
||||
# Fail loud rather than train silently on a zero loss.
|
||||
raise RuntimeError(
|
||||
"SmolVLA2Policy.forward: nothing to train — "
|
||||
"flow_loss_weight=%s, text_loss_weight=%s, "
|
||||
"predict_actions.any()=%s, has_text_data=%s"
|
||||
% (
|
||||
self.config.flow_loss_weight,
|
||||
self.config.text_loss_weight,
|
||||
bool(predict_actions_t.any().item()) if has_per_sample_routing else None,
|
||||
has_text_data,
|
||||
)
|
||||
)
|
||||
|
||||
loss_dict["loss"] = float(total.detach().item())
|
||||
|
||||
if reduction == "none":
|
||||
# Per-sample loss isn't meaningfully defined for the dual
|
||||
# path; broadcast the scalar to (B,) for caller compat.
|
||||
return total.expand(batch[OBS_STATE].shape[0]), loss_dict
|
||||
return total, loss_dict
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text-loss internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||
"""Cross-entropy on the SmolVLM ``lm_head`` over target tokens."""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
# Causally mask the supervised target span so the text-CE is
|
||||
# genuine next-token prediction (see ``_mark_target_span_causal``).
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
prefix_att_masks = _mark_target_span_causal(
|
||||
prefix_att_masks, text_labels, lang_start, lang_end
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Prefix-only forward.
|
||||
out_pair, _ = self.model.vlm_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||
if prefix_out is None:
|
||||
raise RuntimeError(
|
||||
"SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
|
||||
"states — text-loss path needs them."
|
||||
)
|
||||
|
||||
# ``lang_start`` / ``lang_end`` were located above on the
|
||||
# *unmodified* att masks — don't recompute here, because
|
||||
# ``_mark_target_span_causal`` set target lang tokens to 1 and
|
||||
# ``_locate_lang_range`` keys on the first 1 (the state token).
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
||||
|
||||
return _shifted_ce(logits, text_labels)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fused flow + text loss (single backbone forward)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_fused_loss(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
text_labels: Tensor,
|
||||
predict_actions_t: Tensor | None,
|
||||
noise: Tensor | None = None,
|
||||
time: Tensor | None = None,
|
||||
) -> tuple[Tensor, Tensor, dict[str, Any]]:
|
||||
"""One backbone forward → both flow MSE and text CE.
|
||||
|
||||
Mirrors ``SmolVLAModel.forward`` (prefix + suffix concat, one
|
||||
``vlm_with_expert`` call) but captures **both** outputs:
|
||||
|
||||
* ``prefix_out[:, lang_start:lang_end]`` → ``lm_head`` → CE on
|
||||
``text_labels`` (same slicing as ``_compute_text_loss``).
|
||||
* ``suffix_out[:, -chunk_size:]`` → ``action_out_proj`` → flow
|
||||
MSE against ``noise - actions`` (same as the parent forward).
|
||||
|
||||
Saves one backbone pass per training step vs. running the flow
|
||||
and text paths separately — same trick PI052Policy uses in
|
||||
``_compute_all_losses_fused``.
|
||||
"""
|
||||
cfg = self.config
|
||||
if cfg.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
inner = self.model
|
||||
if noise is None:
|
||||
noise = inner.sample_noise(actions.shape, actions.device)
|
||||
if time is None:
|
||||
time = inner.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
# Causally mask the supervised text-target span (see
|
||||
# ``_mark_target_span_causal``). Per-sample: high_level_subtask
|
||||
# samples have a subtask target → causal; low_level_execution
|
||||
# samples have all -100 labels → untouched / bidirectional, so
|
||||
# the action expert still reads the subtask as bidirectional
|
||||
# context. ``lang_start`` / ``lang_end`` located here on the
|
||||
# unmodified mask and reused for the text-loss slice below.
|
||||
lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
prefix_att_masks = _mark_target_span_causal(
|
||||
prefix_att_masks, text_labels, lang_start, lang_end
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
out_pair, _ = inner.vlm_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
prefix_out, suffix_out = out_pair[0], out_pair[1]
|
||||
if prefix_out is None or suffix_out is None:
|
||||
raise RuntimeError(
|
||||
"SmolVLA2: fused forward expected both prefix and suffix "
|
||||
"hidden states from vlm_with_expert."
|
||||
)
|
||||
|
||||
# ---------------- flow loss (per-sample maskable) ----------------
|
||||
chunk = cfg.chunk_size
|
||||
suffix_chunk = suffix_out[:, -chunk:].to(torch.float32)
|
||||
v_t = inner.action_out_proj(suffix_chunk)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
original_action_dim = cfg.action_feature.shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
flow_diag = {"losses_after_forward": float(losses.detach().mean().item())}
|
||||
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
if actions_is_pad is not None:
|
||||
in_episode = ~actions_is_pad
|
||||
losses = losses * in_episode.unsqueeze(-1)
|
||||
flow_diag["losses_after_in_ep_bound"] = float(losses.detach().mean().item())
|
||||
|
||||
losses = losses[:, :, : cfg.max_action_dim]
|
||||
flow_diag["losses_after_rm_padding"] = float(losses.detach().mean().item())
|
||||
|
||||
per_sample_flow = losses.mean(dim=(1, 2))
|
||||
if predict_actions_t is not None:
|
||||
mask = predict_actions_t.to(per_sample_flow.dtype)
|
||||
flow_loss = (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0)
|
||||
else:
|
||||
flow_loss = per_sample_flow.mean()
|
||||
|
||||
# ---------------- text loss (lang slice of prefix) ---------------
|
||||
# ``lang_start`` / ``lang_end`` from above (unmodified mask).
|
||||
vlm = inner.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden)
|
||||
text_loss = _shifted_ce(logits, text_labels)
|
||||
|
||||
return flow_loss, text_loss, flow_diag
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference: text generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def select_message(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
*,
|
||||
max_new_tokens: int = 256,
|
||||
min_new_tokens: int = 0,
|
||||
eos_token_id: int | None = None,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
tokenizer: Any = None,
|
||||
) -> str:
|
||||
"""Generate text continuation from the chat-templated prompt.
|
||||
|
||||
AR decoding with KV caching reused from SmolVLA's inference
|
||||
path. Batch size is assumed to be 1 (the runtime calls this
|
||||
per-event). Returns the decoded string of new tokens (the
|
||||
prompt itself is not included).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch:
|
||||
Already through the SmolVLA2 preprocessor — expects
|
||||
``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``,
|
||||
``OBS_LANGUAGE_ATTENTION_MASK``.
|
||||
max_new_tokens:
|
||||
Hard cap on generated tokens; stops earlier on EOS.
|
||||
eos_token_id:
|
||||
Override the tokenizer's EOS. ``None`` ⇒ use the
|
||||
tokenizer's default.
|
||||
temperature, top_p:
|
||||
``temperature=0`` does greedy argmax (default — matches
|
||||
training distribution most closely). Set ``temperature>0``
|
||||
with optional ``top_p<1`` for nucleus sampling.
|
||||
tokenizer:
|
||||
Optional pre-loaded tokenizer to avoid the cold-start
|
||||
``AutoTokenizer.from_pretrained`` round-trip on every call.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if tokenizer is None:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name)
|
||||
if eos_token_id is None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
# Build the full set of special-token ids to suppress during
|
||||
# the ``min_new_tokens`` window. EOS alone is not enough on a
|
||||
# memorised SmolVLM head — when EOS is masked, the argmax
|
||||
# falls onto a sibling special token (``<|im_end|>``,
|
||||
# ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``,
|
||||
# …) which then survives generation but gets stripped by
|
||||
# ``skip_special_tokens=True`` so ``decode`` returns an empty
|
||||
# string and the runtime sees ``last_raw='(empty)'`` every
|
||||
# chunk boundary.
|
||||
special_ids_set: set[int] = set()
|
||||
try:
|
||||
for sid in (tokenizer.all_special_ids or []):
|
||||
if sid is not None:
|
||||
special_ids_set.add(int(sid))
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
if eos_token_id is not None:
|
||||
special_ids_set.add(int(eos_token_id))
|
||||
|
||||
# Match training's text-loss forward path (see
|
||||
# ``_compute_text_loss`` above): build the full prefix via
|
||||
# ``embed_prefix`` so images + state conditioning is intact,
|
||||
# then loop AR with ``fill_kv_cache=True, use_cache=False``.
|
||||
# That flag combo routes every layer through
|
||||
# ``forward_attn_layer`` (which gracefully skips ``None``
|
||||
# expert inputs via ``if hidden_states is None or layer is
|
||||
# None: continue``) and short-circuits the cache-update logic
|
||||
# so we don't have to manage past_kv. Each step just
|
||||
# re-forwards the cumulative ``[prefix + generated]``
|
||||
# sequence.
|
||||
#
|
||||
# This is O(n²) in generated text length but cheap in
|
||||
# absolute terms: image encoding happens once via the initial
|
||||
# ``embed_prefix`` call, and the per-step cost is just one
|
||||
# SmolVLM transformer pass over a sequence that grows by one
|
||||
# token each time. Standard SmolVLM ``generate`` was the
|
||||
# other tempting path, but it can't accept SmolVLA's custom
|
||||
# ``state_proj`` output and its tile-grid expectations
|
||||
# disagree with our preprocessor — both lead to garbage
|
||||
# generation, which is what the prior approach produced.
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
|
||||
# ``embed_prefix`` lays the prefix out as ``[images, lang, state]``
|
||||
# — the state token is LAST. Training supervises the text head on
|
||||
# the *language* positions (see ``_compute_text_loss`` /
|
||||
# ``_compute_fused_loss``: lm_head over ``prefix_out[lang_start:
|
||||
# lang_end]``). So AR text generation must continue from the last
|
||||
# language token (right after the ``Assistant:`` generation
|
||||
# prompt) — NOT from the state token, whose hidden state exists
|
||||
# for the action expert to read and which the lm_head was never
|
||||
# trained to decode subtask text from. Truncating the state token
|
||||
# here makes ``prefix_out[:, -1:]`` in the loop below the last
|
||||
# language position, matching the training distribution.
|
||||
_, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1])
|
||||
prefix_embs = prefix_embs[:, :lang_end]
|
||||
prefix_pad_masks = prefix_pad_masks[:, :lang_end]
|
||||
prefix_att_masks = prefix_att_masks[:, :lang_end]
|
||||
|
||||
device = prefix_embs.device
|
||||
bsize = prefix_embs.shape[0]
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
text_emb_scale = math.sqrt(emb_dim)
|
||||
|
||||
current_embs = prefix_embs
|
||||
current_pad = prefix_pad_masks
|
||||
current_att = prefix_att_masks
|
||||
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
|
||||
generated: list[int] = []
|
||||
for _ in range(max_new_tokens):
|
||||
full_2d = make_att_2d_masks(current_pad, current_att)
|
||||
full_pos = torch.cumsum(current_pad, dim=1) - 1
|
||||
|
||||
out_pair, _ = self.model.vlm_with_expert.forward(
|
||||
attention_mask=full_2d,
|
||||
position_ids=full_pos,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[current_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||
if prefix_out is None:
|
||||
raise RuntimeError(
|
||||
"select_message: vlm_with_expert.forward returned no hidden states."
|
||||
)
|
||||
|
||||
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
|
||||
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||
# Suppress *all* special tokens until we've decoded
|
||||
# ``min_new_tokens`` real (renderable) tokens. Without
|
||||
# this, a memorised SmolVLM head whose argmax at position
|
||||
# 0 is a special token produces an empty completion every
|
||||
# time — either EOS directly, or (after we mask EOS) the
|
||||
# argmax shifts to a sibling special id (``<|im_end|>``,
|
||||
# ``<image>``, ``<row_X_col_Y>``, …) which decode strips
|
||||
# via ``skip_special_tokens=True``. Masking the full
|
||||
# ``all_special_ids`` set for the first N steps forces
|
||||
# the head to commit to a normal vocabulary token before
|
||||
# it can close (or quietly poison) the turn.
|
||||
if special_ids_set and len(generated) < min_new_tokens:
|
||||
for sid in special_ids_set:
|
||||
logits_step[..., sid] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
tok_id = int(next_ids[0].item())
|
||||
generated.append(tok_id)
|
||||
if eos_token_id is not None and tok_id == eos_token_id:
|
||||
break
|
||||
|
||||
new_emb = self.model.vlm_with_expert.embed_language_tokens(
|
||||
next_ids.unsqueeze(0)
|
||||
)
|
||||
new_emb = new_emb * text_emb_scale
|
||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
|
||||
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||
# When the visible decoded string is empty but tokens *were*
|
||||
# generated, expose what those raw tokens decoded to without
|
||||
# the special-token filter. This is what the runtime turns
|
||||
# into a scrollback line when ``last_raw='(empty)'`` so the
|
||||
# operator can tell whether the head is emitting EOS, image
|
||||
# placeholder tokens, the chat-template ``<|im_end|>`` shard,
|
||||
# or something else.
|
||||
if not decoded and generated:
|
||||
try:
|
||||
self._last_select_message_debug = (
|
||||
f"raw_ids={generated[:16]} "
|
||||
f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
self._last_select_message_debug = f"raw_ids={generated[:16]}"
|
||||
else:
|
||||
self._last_select_message_debug = ""
|
||||
return decoded
|
||||
|
||||
@staticmethod
|
||||
def _sample_next_token(
|
||||
logits: Tensor, temperature: float, top_p: float
|
||||
) -> Tensor:
|
||||
"""Pick one token id per batch row from ``logits``."""
|
||||
if temperature <= 0.0:
|
||||
return logits.argmax(dim=-1)
|
||||
scaled = logits / max(temperature, 1e-6)
|
||||
probs = F.softmax(scaled, dim=-1)
|
||||
if top_p < 1.0:
|
||||
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
|
||||
cum = sorted_probs.cumsum(dim=-1)
|
||||
mask = cum > top_p
|
||||
# Always keep the most-likely token.
|
||||
mask[..., 0] = False
|
||||
sorted_probs = sorted_probs.masked_fill(mask, 0.0)
|
||||
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
|
||||
pick = torch.multinomial(sorted_probs, num_samples=1)
|
||||
return sorted_idx.gather(-1, pick).squeeze(-1)
|
||||
return torch.multinomial(probs, num_samples=1).squeeze(-1)
|
||||
@@ -1,134 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 processor pipelines.
|
||||
|
||||
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||
|
||||
rename observations
|
||||
add batch dim
|
||||
RenderMessagesStep(recipe) # PR 1: language_* → messages
|
||||
SmolVLA2ChatTokenizerStep(...) # chat template + label mask + predict_actions
|
||||
DeviceProcessorStep
|
||||
NormalizerProcessorStep
|
||||
|
||||
When ``config.recipe_path`` is ``None``, we delegate to SmolVLA's
|
||||
plain task-string pipeline so unannotated datasets still work.
|
||||
|
||||
Post-processor is unchanged from SmolVLA.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
RenderMessagesStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
from .chat_processor_smolvla2 import SmolVLA2ChatTokenizerStep
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
def make_smolvla2_pre_post_processors(
|
||||
config: SmolVLA2Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build SmolVLA2's pre/post-processor pipelines.
|
||||
|
||||
With ``recipe_path`` set, inserts the recipe-rendering step and the
|
||||
chat-template tokenizer that emits ``text_labels`` and
|
||||
``predict_actions`` for the dual-loss path. Without it, falls back
|
||||
to SmolVLA's plain task-string pipeline so unannotated datasets
|
||||
keep working unchanged.
|
||||
"""
|
||||
if not config.recipe_path:
|
||||
return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||
|
||||
recipe = _load_recipe(config.recipe_path)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RenderMessagesStep(recipe=recipe),
|
||||
SmolVLA2ChatTokenizerStep(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding=config.pad_language_to,
|
||||
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
|
||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/`` so recipe authors can write
|
||||
``--policy.recipe_path=recipes/subtasks_vqa.yaml``.
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||
|
||||
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||
candidate = configs_dir / path_str
|
||||
if candidate.exists():
|
||||
p = candidate
|
||||
return TrainingRecipe.from_yaml(p)
|
||||
+44
-48
@@ -12,10 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2.
|
||||
"""``lerobot-pi052-runtime`` — interactive REPL for trained PI052.
|
||||
|
||||
Drives the multi-rate runtime defined in
|
||||
:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user
|
||||
:mod:`lerobot.policies.pi052.inference`. Stdin becomes the user
|
||||
channel: type a task, then natural-language interjections / questions.
|
||||
The runtime prints state changes (plan / subtask / memory / vqa /
|
||||
speech) as they happen.
|
||||
@@ -26,16 +26,16 @@ Examples
|
||||
Dry run on a Hub checkpoint, no robot connected — useful for sanity-
|
||||
checking text generation::
|
||||
|
||||
uv run lerobot-smolvla2-runtime \\
|
||||
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
|
||||
uv run lerobot-pi052-runtime \\
|
||||
--policy.path=pepijn223/pi052_hirobot_super_poulain_tool2 \\
|
||||
--no_robot \\
|
||||
--task="please clean the kitchen"
|
||||
|
||||
Same, but feed real frames from an annotated dataset so plan / subtask
|
||||
/ memory / VQA generation runs against actual video + state::
|
||||
|
||||
uv run lerobot-smolvla2-runtime \\
|
||||
--policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\
|
||||
uv run lerobot-pi052-runtime \\
|
||||
--policy.path=pepijn223/pi052_hirobot_super_poulain_tool2 \\
|
||||
--dataset.repo_id=pepijn223/super_poulain_annotated \\
|
||||
--dataset.episode=0 \\
|
||||
--no_robot \\
|
||||
@@ -43,7 +43,7 @@ Same, but feed real frames from an annotated dataset so plan / subtask
|
||||
|
||||
With a real robot::
|
||||
|
||||
uv run lerobot-smolvla2-runtime \\
|
||||
uv run lerobot-pi052-runtime \\
|
||||
--policy.path=... \\
|
||||
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
|
||||
--tts.voice=alba
|
||||
@@ -62,18 +62,14 @@ import logging
|
||||
import sys
|
||||
from typing import Any, Callable
|
||||
|
||||
logger = logging.getLogger("lerobot.smolvla2.runtime")
|
||||
logger = logging.getLogger("lerobot.pi052.runtime")
|
||||
|
||||
|
||||
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(
|
||||
# prog defaults to the invoked command name, so this reads
|
||||
# correctly whether run as lerobot-smolvla2-runtime or
|
||||
# lerobot-pi052-runtime.
|
||||
description=(
|
||||
"Interactive REPL runtime for a trained hierarchical VLA "
|
||||
"checkpoint (SmolVLA2 or PI052 — policy type is read from "
|
||||
"the checkpoint)."
|
||||
"Interactive REPL runtime for a trained PI052 hierarchical "
|
||||
"VLA checkpoint."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
@@ -82,7 +78,7 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Local directory or Hugging Face Hub repo id pointing at a trained SmolVLA2 ``pretrained_model``."
|
||||
"Local directory or Hugging Face Hub repo id pointing at a trained PI052 ``pretrained_model``."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
@@ -368,7 +364,7 @@ def _load_policy_and_preprocessor(
|
||||
policy_path: str,
|
||||
dataset_repo_id: str | None,
|
||||
) -> tuple[Any, Any, Any, Any]:
|
||||
"""Load a SmolVLA2 checkpoint (local path or Hub repo id).
|
||||
"""Load a PI052 checkpoint (local path or Hub repo id).
|
||||
|
||||
Returns ``(policy, preprocessor, postprocessor, ds_meta)``.
|
||||
``preprocessor`` / ``postprocessor`` / ``ds_meta`` are ``None``
|
||||
@@ -430,7 +426,7 @@ def _build_observation_provider(
|
||||
|
||||
The dataset's ``language_persistent`` / ``language_events``
|
||||
columns are stripped before the sample reaches the preprocessor,
|
||||
so ``RenderMessagesStep`` and ``SmolVLA2ChatTokenizerStep`` are
|
||||
so ``RenderMessagesStep`` and ``PI052TextTokenizerStep`` are
|
||||
no-ops; the runtime supplies its own messages from current state.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
@@ -613,7 +609,7 @@ def _select_task_interactively(
|
||||
# bootstrap default (may be None — REPL handles that).
|
||||
return bootstrap_task
|
||||
|
||||
print("\n[smolvla2] Select startup task:", flush=True)
|
||||
print("\n[pi052] Select startup task:", flush=True)
|
||||
if options:
|
||||
for i, opt in enumerate(options, 1):
|
||||
marker = " (dataset default)" if opt == bootstrap_task else ""
|
||||
@@ -816,7 +812,7 @@ def _build_robot_observation_provider(
|
||||
# ``ds_features``. The training distribution sees frames at the
|
||||
# recorded resolution (e.g. 480×640); a live Mac/USB camera will
|
||||
# almost always hand us a different native size (720p / 1080p).
|
||||
# SmolVLA's internal ``resize_with_pad(512, 512)`` does pad the
|
||||
# PI052's internal ``resize_with_pad(512, 512)`` does pad the
|
||||
# input to a fixed canvas, but the *geometry* of that pad differs
|
||||
# by input aspect ratio — top/left padding varies, so the visual
|
||||
# tokens at each tile carry different content than what the model
|
||||
@@ -1017,7 +1013,7 @@ def _build_robot_action_executor(
|
||||
def _print_runtime_help() -> None:
|
||||
"""Print the slash-command reference."""
|
||||
print(
|
||||
"[smolvla2] commands (arguments need no quotes):\n"
|
||||
"[pi052] commands (arguments need no quotes):\n"
|
||||
" /action <task> run the robot; an argument switches to that task\n"
|
||||
" /action resume the robot on the current task\n"
|
||||
" /action <seconds> run the robot for N seconds, then auto-pause\n"
|
||||
@@ -1085,7 +1081,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
|
||||
secs = float(rest)
|
||||
runtime.state["action_deadline"] = _time.monotonic() + secs
|
||||
print(
|
||||
f"[smolvla2] action — running {secs:g}s, then auto-pause",
|
||||
f"[pi052] action — running {secs:g}s, then auto-pause",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
@@ -1095,16 +1091,16 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
|
||||
# New task → drop the stale subtask so the high-level
|
||||
# loop regenerates one for the new goal.
|
||||
runtime.state["current_subtask"] = None
|
||||
print(f"[smolvla2] action — task: {rest!r}", flush=True)
|
||||
print(f"[pi052] action — task: {rest!r}", flush=True)
|
||||
elif runtime.state.get("task"):
|
||||
print(
|
||||
f"[smolvla2] action — resuming: {runtime.state['task']!r}",
|
||||
f"[pi052] action — resuming: {runtime.state['task']!r}",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
runtime.state["mode"] = "paused"
|
||||
print(
|
||||
"[smolvla2] no task set — use /action <your task>",
|
||||
"[pi052] no task set — use /action <your task>",
|
||||
flush=True,
|
||||
)
|
||||
return True
|
||||
@@ -1113,7 +1109,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
|
||||
runtime.state["mode"] = "paused"
|
||||
runtime.state["action_deadline"] = None
|
||||
_clear_action_queue(runtime)
|
||||
print("[smolvla2] paused — robot holding position", flush=True)
|
||||
print("[pi052] paused — robot holding position", flush=True)
|
||||
return True
|
||||
|
||||
if cmd in {"/question", "/q", "/ask", "/vqa", "/vlm"}:
|
||||
@@ -1124,7 +1120,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool:
|
||||
_clear_action_queue(runtime)
|
||||
if not rest:
|
||||
print(
|
||||
"[smolvla2] usage: /question <your question> "
|
||||
"[pi052] usage: /question <your question> "
|
||||
"(e.g. /question point to the yellow cube)",
|
||||
flush=True,
|
||||
)
|
||||
@@ -1144,7 +1140,7 @@ def _run_vqa_query(runtime: Any, question: str) -> None:
|
||||
Invoked by ``/question`` — the action loop is paused first so the
|
||||
policy is free for a synchronous VQA call.
|
||||
"""
|
||||
from lerobot.policies.smolvla2.inference.vqa import handle_vqa_query # noqa: PLC0415
|
||||
from lerobot.policies.pi052.inference.vqa import handle_vqa_query # noqa: PLC0415
|
||||
|
||||
handle_vqa_query(
|
||||
policy=runtime.policy,
|
||||
@@ -1180,11 +1176,11 @@ def _run_autonomous(
|
||||
if not auto_start and runtime.state.get("mode", "paused") == "action":
|
||||
try:
|
||||
input(
|
||||
"[smolvla2] Robot connected — starting in ACTION mode. "
|
||||
"[pi052] Robot connected — starting in ACTION mode. "
|
||||
"Press ENTER to begin, Ctrl+C to abort. "
|
||||
)
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n[smolvla2] aborted before start", flush=True)
|
||||
print("\n[pi052] aborted before start", flush=True)
|
||||
return 130
|
||||
|
||||
if initial_task:
|
||||
@@ -1193,7 +1189,7 @@ def _run_autonomous(
|
||||
thread = threading.Thread(
|
||||
target=runtime.run,
|
||||
kwargs={"max_ticks": max_ticks},
|
||||
name="smolvla2-runtime-loop",
|
||||
name="pi052-runtime-loop",
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
@@ -1251,7 +1247,7 @@ def _run_autonomous(
|
||||
if hasattr(queue, "clear"):
|
||||
queue.clear()
|
||||
print(
|
||||
"\n[smolvla2] timed action elapsed — paused",
|
||||
"\n[pi052] timed action elapsed — paused",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
@@ -1264,7 +1260,7 @@ def _run_autonomous(
|
||||
pass
|
||||
_panel_stop.wait(0.7)
|
||||
|
||||
panel_thread = threading.Thread(target=_panel_loop, name="smolvla2-panel-redraw", daemon=True)
|
||||
panel_thread = threading.Thread(target=_panel_loop, name="pi052-panel-redraw", daemon=True)
|
||||
panel_thread.start()
|
||||
|
||||
try:
|
||||
@@ -1296,11 +1292,11 @@ def _run_autonomous(
|
||||
runtime.state.setdefault("events_this_tick", []).append("user_interjection")
|
||||
else:
|
||||
print(
|
||||
"[smolvla2] no task yet — use /action <your task> to start",
|
||||
"[pi052] no task yet — use /action <your task> to start",
|
||||
flush=True,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("\n[smolvla2] interrupt — stopping", flush=True)
|
||||
print("\n[pi052] interrupt — stopping", flush=True)
|
||||
finally:
|
||||
_panel_stop.set()
|
||||
runtime.stop()
|
||||
@@ -1311,9 +1307,9 @@ def _run_autonomous(
|
||||
time.sleep(0.1)
|
||||
try:
|
||||
robot.disconnect()
|
||||
print("[smolvla2] robot disconnected", flush=True)
|
||||
print("[pi052] robot disconnected", flush=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[smolvla2] WARNING: robot.disconnect raised {exc}", flush=True)
|
||||
print(f"[pi052] WARNING: robot.disconnect raised {exc}", flush=True)
|
||||
|
||||
return 0
|
||||
|
||||
@@ -1340,7 +1336,7 @@ def _make_state_panel_renderer(
|
||||
st = runtime.state
|
||||
run_mode = st.get("mode", "action")
|
||||
mode_tag = "[green]mode: action[/]" if run_mode == "action" else "[yellow]mode: paused[/]"
|
||||
console.rule(f"[bold]SmolVLA2[/] · {mode_label} · {mode_tag}", style="cyan")
|
||||
console.rule(f"[bold]PI052[/] · {mode_label} · {mode_tag}", style="cyan")
|
||||
# Always-visible command hint so the operator never has to
|
||||
# remember the slash commands.
|
||||
if run_mode == "action":
|
||||
@@ -1499,14 +1495,14 @@ def main(argv: list[str] | None = None) -> int:
|
||||
autonomous_mode = bool(args.robot_type) and not args.no_robot
|
||||
if autonomous_mode and not args.dataset_repo_id:
|
||||
print(
|
||||
"[smolvla2] ERROR: autonomous robot mode requires --dataset.repo_id "
|
||||
"[pi052] ERROR: autonomous robot mode requires --dataset.repo_id "
|
||||
"for action-denormalisation stats and feature shapes. Pass the "
|
||||
"same dataset the policy was trained on.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
||||
print(f"[pi052] loading policy from {args.policy_path}", flush=True)
|
||||
policy, preprocessor, postprocessor, ds_meta = _load_policy_and_preprocessor(
|
||||
args.policy_path, args.dataset_repo_id
|
||||
)
|
||||
@@ -1537,7 +1533,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
)
|
||||
if chosen:
|
||||
args.task = chosen
|
||||
print(f"[smolvla2] task: {args.task!r}", flush=True)
|
||||
print(f"[pi052] task: {args.task!r}", flush=True)
|
||||
|
||||
# No startup prompts — the runtime is command-driven. It comes up at
|
||||
# the command line in ``paused`` mode (robot idle) unless ``--mode``
|
||||
@@ -1551,7 +1547,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
|
||||
if autonomous_mode:
|
||||
print(
|
||||
f"[smolvla2] connecting to robot.type={args.robot_type} port={args.robot_port}",
|
||||
f"[pi052] connecting to robot.type={args.robot_type} port={args.robot_port}",
|
||||
flush=True,
|
||||
)
|
||||
robot = _build_robot(
|
||||
@@ -1575,7 +1571,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
)
|
||||
elif args.dataset_repo_id is not None:
|
||||
print(
|
||||
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
||||
f"[pi052] streaming observations from {args.dataset_repo_id} "
|
||||
f"episode={args.dataset_episode} "
|
||||
f"start_frame={args.dataset_start_frame}",
|
||||
flush=True,
|
||||
@@ -1592,11 +1588,11 @@ def main(argv: list[str] | None = None) -> int:
|
||||
|
||||
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||
if tools:
|
||||
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
|
||||
print(f"[pi052] tools loaded: {list(tools)}", flush=True)
|
||||
|
||||
from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415
|
||||
from lerobot.policies.pi052.inference import PI052Runtime # noqa: PLC0415
|
||||
|
||||
runtime = SmolVLA2Runtime(
|
||||
runtime = PI052Runtime(
|
||||
policy=policy,
|
||||
tools=tools,
|
||||
observation_provider=observation_provider,
|
||||
@@ -1662,7 +1658,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
logger.warning("startup tick failed: %s", exc)
|
||||
startup_logs = []
|
||||
for line in startup_logs or []:
|
||||
print(f"[smolvla2] {line}", flush=True)
|
||||
print(f"[pi052] {line}", flush=True)
|
||||
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
||||
|
||||
|
||||
@@ -1680,7 +1676,7 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
|
||||
from rich.console import Console # noqa: PLC0415
|
||||
except ImportError:
|
||||
print(
|
||||
"[smolvla2] rich is required for the interactive REPL. `pip install rich` and re-run.",
|
||||
"[pi052] rich is required for the interactive REPL. `pip install rich` and re-run.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
@@ -1724,7 +1720,7 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None)
|
||||
# task to be meaningful.
|
||||
if not runtime.state.get("task"):
|
||||
print(
|
||||
"[smolvla2] no task yet — use /action <your task>",
|
||||
"[pi052] no task yet — use /action <your task>",
|
||||
flush=True,
|
||||
)
|
||||
_redraw(last_logs)
|
||||
@@ -13,10 +13,9 @@
|
||||
# limitations under the License.
|
||||
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
|
||||
|
||||
The first concrete tool implementation. SmolVLA2 (PR 3) and downstream
|
||||
runtime dispatchers consume this when the model emits an assistant
|
||||
message with ``tool_calls=[{function: {name: "say", arguments:
|
||||
{text: ...}}}]``.
|
||||
The first concrete tool implementation. PI052 and downstream runtime
|
||||
dispatchers consume this when the model emits an assistant message
|
||||
with ``tool_calls=[{function: {name: "say", arguments: {text: ...}}}]``.
|
||||
|
||||
Why pocket-tts:
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ def test_messages_vqa_to_loc_noop_without_target_indices():
|
||||
|
||||
|
||||
def test_loc_round_trip_keypoint_preserves_normalized_coords():
|
||||
from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer
|
||||
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
|
||||
|
||||
answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
|
||||
loc = _vqa_answer_to_loc(answer)
|
||||
@@ -175,7 +175,7 @@ def test_loc_round_trip_keypoint_preserves_normalized_coords():
|
||||
|
||||
|
||||
def test_loc_round_trip_bbox_preserves_order_and_scale():
|
||||
from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer
|
||||
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
|
||||
|
||||
answer = {
|
||||
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}]
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Attention-masking tests for the SmolVLA2 text head.
|
||||
|
||||
Regression coverage for the text-CE collapse bug: ``embed_prefix`` flags
|
||||
every language token ``att=0``, which ``make_att_2d_masks`` turns into a
|
||||
single fully *bidirectional* block. Under that mask the text
|
||||
cross-entropy degenerates into a copy task — a supervised target token
|
||||
attends to the tokens it is trained to predict — and the model never
|
||||
learns causal generation, so ``select_message`` collapses at inference.
|
||||
|
||||
``_mark_target_span_causal`` sets ``att=1`` on the supervised target
|
||||
language positions so each target token attends causally among the
|
||||
targets while staying bidirectional to images + the user prompt. These
|
||||
tests pin that behaviour.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# The smolvla2 modeling module imports transformers transitively.
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.smolvla.modeling_smolvla import make_att_2d_masks # noqa: E402
|
||||
from lerobot.policies.smolvla2.modeling_smolvla2 import ( # noqa: E402
|
||||
_locate_lang_range,
|
||||
_mark_target_span_causal,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A synthetic SmolVLA prefix layout: [images, prompt-lang, target-lang, state]
|
||||
#
|
||||
# indices 0-1 : 2 image tokens (att = 0)
|
||||
# indices 2-4 : 3 user-prompt lang (att = 0)
|
||||
# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix)
|
||||
# index 9 : 1 state token (att = 1)
|
||||
#
|
||||
# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
|
||||
# real ids on the 4-token target span.
|
||||
# ---------------------------------------------------------------------------
|
||||
N_IMAGE = 2
|
||||
N_PROMPT = 3
|
||||
N_TARGET = 4
|
||||
LANG_START = N_IMAGE
|
||||
LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = state-token index
|
||||
PREFIX_LEN = LANG_END + 1
|
||||
|
||||
|
||||
def _embed_prefix_att_masks() -> torch.Tensor:
|
||||
"""Mimic ``embed_prefix``: images + lang all att=0, state att=1."""
|
||||
att = torch.zeros(1, PREFIX_LEN, dtype=torch.bool)
|
||||
att[0, LANG_END] = True # the single state token
|
||||
return att
|
||||
|
||||
|
||||
def _text_labels() -> torch.Tensor:
|
||||
"""-100 over the prompt span, real ids over the target span."""
|
||||
labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
|
||||
labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET)
|
||||
return labels
|
||||
|
||||
|
||||
def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
|
||||
"""2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j."""
|
||||
pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool)
|
||||
return make_att_2d_masks(pad, prefix_att_masks)[0]
|
||||
|
||||
|
||||
def test_locate_lang_range_anchors_on_state_token():
|
||||
"""``_locate_lang_range`` finds the lang span via the lone att=1 token."""
|
||||
lang_start, lang_end = _locate_lang_range(
|
||||
_embed_prefix_att_masks(), num_lang=N_PROMPT + N_TARGET
|
||||
)
|
||||
assert (lang_start, lang_end) == (LANG_START, LANG_END)
|
||||
|
||||
|
||||
def test_mark_sets_att_on_targets_only():
|
||||
"""Only the supervised target language positions flip to att=1."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
expected = [False] * PREFIX_LEN
|
||||
for i in range(LANG_START + N_PROMPT, LANG_END): # target span
|
||||
expected[i] = True
|
||||
expected[LANG_END] = True # state token, untouched
|
||||
assert marked[0].tolist() == expected
|
||||
|
||||
|
||||
def test_target_tokens_attend_causally_among_themselves():
|
||||
"""A target token must NOT attend to later targets, but must attend
|
||||
to earlier ones — i.e. genuine causal next-token prediction."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
tgt = range(LANG_START + N_PROMPT, LANG_END)
|
||||
for i in tgt:
|
||||
for j in tgt:
|
||||
if j > i:
|
||||
assert not attends[i, j], f"target {i} must not see future target {j}"
|
||||
else:
|
||||
assert attends[i, j], f"target {i} must see earlier/self target {j}"
|
||||
|
||||
|
||||
def test_target_tokens_attend_prompt_and_images_bidirectionally():
|
||||
"""Targets keep full visibility of images + the user prompt."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
context = list(range(0, LANG_START + N_PROMPT)) # images + prompt
|
||||
for i in range(LANG_START + N_PROMPT, LANG_END):
|
||||
for j in context:
|
||||
assert attends[i, j], f"target {i} must attend context {j}"
|
||||
|
||||
|
||||
def test_action_expert_token_still_sees_full_subtask():
|
||||
"""The state token (action-expert context) attends to every target —
|
||||
causal masking the targets must not hide them from the action path."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
for j in range(LANG_START + N_PROMPT, LANG_END):
|
||||
assert attends[LANG_END, j], f"state token must see target {j}"
|
||||
|
||||
|
||||
def test_non_target_subtask_stays_bidirectional():
|
||||
"""``low_level_execution`` renders the subtask as a user turn — its
|
||||
``text_labels`` are all -100, so the mask must be left untouched and
|
||||
the action expert reads the subtask bidirectionally."""
|
||||
all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END
|
||||
)
|
||||
assert torch.equal(marked, _embed_prefix_att_masks())
|
||||
|
||||
|
||||
def test_unmarked_mask_is_bidirectional_the_bug():
|
||||
"""Documents the bug the fix prevents: without ``_mark_target_span_causal``
|
||||
a target token attends *bidirectionally* to later targets — the
|
||||
text-CE can copy the answer it is trained to predict."""
|
||||
attends = _attends(_embed_prefix_att_masks())
|
||||
first_tgt = LANG_START + N_PROMPT
|
||||
last_tgt = LANG_END - 1
|
||||
assert attends[first_tgt, last_tgt], (
|
||||
"raw embed_prefix mask is bidirectional over language — the first "
|
||||
"target token can see the last, which is the collapse bug"
|
||||
)
|
||||
@@ -1,77 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for SmolVLA2's chat-tokenizer ``tool_calls`` flattening.
|
||||
|
||||
``_split_plan_and_say`` (inference) expects the model to emit a textual
|
||||
``<say>...</say>`` marker. ``_flatten_say_tool_calls`` is the training-time
|
||||
serializer that produces it: it rewrites an assistant turn's structured
|
||||
``say`` tool call into that marker *inside the content text*, before
|
||||
``apply_chat_template`` runs — so the chat template only tokenizes plain
|
||||
text and the supervised target span trains the model to emit the marker
|
||||
the runtime parses back. These tests pin the round-trip.
|
||||
"""
|
||||
|
||||
from lerobot.policies.smolvla2.chat_processor_smolvla2 import flatten_say_tool_calls
|
||||
from lerobot.policies.smolvla2.inference.steps import _split_plan_and_say
|
||||
|
||||
|
||||
def _say_call(text):
|
||||
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
|
||||
|
||||
|
||||
def test_flatten_appends_say_marker_and_drops_tool_calls():
|
||||
msg = {"role": "assistant", "content": "Pick up the blue cube.", "tool_calls": [_say_call("On it!")]}
|
||||
out = flatten_say_tool_calls(msg)
|
||||
assert "tool_calls" not in out
|
||||
assert out["content"] == "Pick up the blue cube.\n<say>On it!</say>"
|
||||
|
||||
|
||||
def test_flatten_roundtrips_through_inference_parser():
|
||||
"""The marker the serializer writes must be exactly what the inference
|
||||
parser reads back — this is the train/inference contract."""
|
||||
msg = {"role": "assistant", "content": "Move toward the cube.", "tool_calls": [_say_call("Working on it")]}
|
||||
flat = flatten_say_tool_calls(msg)["content"]
|
||||
plan, speech = _split_plan_and_say(flat)
|
||||
assert plan == "Move toward the cube."
|
||||
assert speech == "Working on it"
|
||||
|
||||
|
||||
def test_flatten_accepts_json_string_arguments():
|
||||
"""``arguments`` may arrive as a JSON string rather than a dict."""
|
||||
call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
|
||||
out = flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
|
||||
assert out["content"] == "p\n<say>hello there</say>"
|
||||
|
||||
|
||||
def test_flatten_leaves_messages_without_tool_calls_untouched():
|
||||
msg = {"role": "assistant", "content": "just a plan"}
|
||||
assert flatten_say_tool_calls(msg) == msg
|
||||
|
||||
|
||||
def test_flatten_drops_empty_or_non_say_tool_calls():
|
||||
"""A non-``say`` call (or empty text) leaves content alone but still
|
||||
strips the structured calls so the template renders no JSON block."""
|
||||
weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
|
||||
out = flatten_say_tool_calls({"role": "assistant", "content": "plan only", "tool_calls": [weather]})
|
||||
assert out["content"] == "plan only"
|
||||
assert "tool_calls" not in out
|
||||
|
||||
|
||||
def test_flatten_marker_only_when_content_empty():
|
||||
msg = {"role": "assistant", "content": "", "tool_calls": [_say_call("hi")]}
|
||||
out = flatten_say_tool_calls(msg)
|
||||
assert out["content"] == "<say>hi</say>"
|
||||
@@ -1,228 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the SmolVLA2 runtime's interactive-VQA helpers.
|
||||
|
||||
Covers camera selection, VQA-answer parsing, and the bounding-box /
|
||||
keypoint overlay drawing — the pure functions, no model load.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.smolvla2.inference.vqa import (
|
||||
answer_has_overlay,
|
||||
available_cameras,
|
||||
camera_short_name,
|
||||
draw_vqa_overlay,
|
||||
observation_image_to_pil,
|
||||
parse_vqa_answer,
|
||||
prompt_camera_choice,
|
||||
)
|
||||
|
||||
PIL = pytest.importorskip("PIL")
|
||||
from PIL import Image # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_available_cameras_extracts_and_sorts_image_keys():
|
||||
observation = {
|
||||
"observation.images.wrist": object(),
|
||||
"observation.state": object(),
|
||||
"observation.images.top": object(),
|
||||
"task": "x",
|
||||
}
|
||||
assert available_cameras(observation) == [
|
||||
"observation.images.top",
|
||||
"observation.images.wrist",
|
||||
]
|
||||
|
||||
|
||||
def test_available_cameras_handles_none_and_empty():
|
||||
assert available_cameras(None) == []
|
||||
assert available_cameras({}) == []
|
||||
|
||||
|
||||
def test_camera_short_name_strips_prefix():
|
||||
assert camera_short_name("observation.images.top") == "top"
|
||||
assert camera_short_name("top") == "top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_single_camera_auto_selects():
|
||||
cams = ["observation.images.top"]
|
||||
# input_fn must never be called for a single-camera setup.
|
||||
chosen = prompt_camera_choice(cams, input_fn=_boom, print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_number():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
chosen = prompt_camera_choice(cams, input_fn=lambda _: "2", print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.wrist"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_by_name():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
chosen = prompt_camera_choice(cams, input_fn=lambda _: "top", print_fn=lambda *_: None)
|
||||
assert chosen == "observation.images.top"
|
||||
|
||||
|
||||
def test_prompt_camera_choice_invalid_returns_none():
|
||||
cams = ["observation.images.top", "observation.images.wrist"]
|
||||
assert prompt_camera_choice(cams, input_fn=lambda _: "99", print_fn=lambda *_: None) is None
|
||||
|
||||
|
||||
def _boom(*_args, **_kwargs):
|
||||
raise AssertionError("input_fn should not be called")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_bbox_answer():
|
||||
answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_keypoint_answer():
|
||||
answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}'
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "keypoint"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_count_answer_is_not_an_overlay():
|
||||
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
|
||||
assert parsed["kind"] == "count"
|
||||
assert not answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_invalid_json_returns_none():
|
||||
assert parse_vqa_answer("not json at all") is None
|
||||
assert parse_vqa_answer("") is None
|
||||
# A JSON array is valid JSON but not a VQA answer object.
|
||||
assert parse_vqa_answer("[1, 2, 3]") is None
|
||||
|
||||
|
||||
def test_parse_unknown_shape():
|
||||
parsed = parse_vqa_answer('{"weird": "payload"}')
|
||||
assert parsed["kind"] == "unknown"
|
||||
assert not answer_has_overlay(parsed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _blank(size=(160, 120)):
|
||||
return Image.new("RGB", size, (0, 0, 0))
|
||||
|
||||
|
||||
def test_draw_bbox_overlay_changes_pixels_and_preserves_size():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer(
|
||||
'{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}'
|
||||
)
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
|
||||
|
||||
def test_draw_keypoint_overlay_changes_pixels():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}')
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_non_spatial_leaves_image_unchanged():
|
||||
img = _blank()
|
||||
parsed = parse_vqa_answer('{"label": "cubes", "count": 2}')
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.tobytes() == img.tobytes()
|
||||
|
||||
|
||||
def test_draw_overlay_tolerates_malformed_coordinates():
|
||||
img = _blank()
|
||||
# bbox with the wrong arity must not raise.
|
||||
out = draw_vqa_overlay(img, {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}})
|
||||
assert out.size == img.size
|
||||
|
||||
|
||||
def test_observation_image_to_pil_from_batched_float_array():
|
||||
# (1, C, H, W) float array in [0, 1], the runtime observation shape.
|
||||
arr = np.zeros((1, 3, 24, 32), dtype=np.float32)
|
||||
pil = observation_image_to_pil(arr)
|
||||
assert pil.size == (32, 24)
|
||||
assert pil.mode == "RGB"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PaliGemma <loc>-format answers (PI052 trains spatial VQA in this vocab)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_loc_keypoint_answer():
|
||||
# <locY><locX> label — y=512/1023≈0.5, x=256/1023≈0.25
|
||||
parsed = parse_vqa_answer("<loc0512><loc0256> blue cube")
|
||||
assert parsed["kind"] == "keypoint"
|
||||
assert parsed["normalized"] is True
|
||||
x, y = parsed["payload"]["point"]
|
||||
assert 0.24 < x < 0.26
|
||||
assert 0.49 < y < 0.51
|
||||
assert parsed["payload"]["label"] == "blue cube"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_loc_bbox_answer():
|
||||
# <locY0><locX0><locY1><locX1> label
|
||||
parsed = parse_vqa_answer("<loc0100><loc0080><loc0400><loc0360> yellow cube")
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert parsed["normalized"] is True
|
||||
det = parsed["payload"]["detections"][0]
|
||||
x1, y1, x2, y2 = det["bbox"]
|
||||
assert x1 < x2 and y1 < y2
|
||||
assert det["label"] == "yellow cube"
|
||||
assert answer_has_overlay(parsed)
|
||||
|
||||
|
||||
def test_parse_loc_multiple_boxes():
|
||||
answer = "<loc0100><loc0080><loc0400><loc0360> cube ; <loc0200><loc0500><loc0600><loc0900> box"
|
||||
parsed = parse_vqa_answer(answer)
|
||||
assert parsed["kind"] == "bbox"
|
||||
assert len(parsed["payload"]["detections"]) == 2
|
||||
|
||||
|
||||
def test_parse_loc_takes_precedence_over_json():
|
||||
# An answer with <loc> tokens is parsed as loc even if JSON-ish.
|
||||
assert parse_vqa_answer('{"x": <loc0001><loc0002>}')["normalized"] is True
|
||||
|
||||
|
||||
def test_draw_loc_overlay_denormalizes_to_pixels():
|
||||
img = _blank((200, 100))
|
||||
parsed = parse_vqa_answer("<loc0511><loc0511> cube") # ~centre
|
||||
out = draw_vqa_overlay(img, parsed)
|
||||
assert out.size == img.size
|
||||
assert out.tobytes() != img.tobytes()
|
||||
Reference in New Issue
Block a user