diff --git a/examples/port_datasets/slurm_build_robocasa_composite_seen.py b/examples/port_datasets/slurm_build_robocasa_composite_seen.py new file mode 100644 index 000000000..24f33e935 --- /dev/null +++ b/examples/port_datasets/slurm_build_robocasa_composite_seen.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Build the merged RoboCasa composite_seen dataset on SLURM via datatrove. + +Two-phase pipeline modeled after ``slurm_port_shards.py`` + +``slurm_aggregate_shards.py``: + +* Phase 1 — DOWNLOAD (parallel, 16 tasks, 1 per worker): + Each datatrove worker downloads one of the 16 RoboCasa composite_seen task + archives (``v1.0/target/composite///lerobot.tar``) via + RoboCasa's own ``download_datasets`` helper. Idempotent — already-extracted + tasks are skipped. Network-bound, CPU-light. + +* Phase 2 — AGGREGATE (single worker, depends on phase 1): + One worker calls ``aggregate_datasets`` over the 16 extracted directories, + producing a single combined LeRobotDataset. Validates fps / robot_type / + features, unifies task indices, concatenates videos + parquet, recomputes + stats. CPU + disk heavy. + +When run under SLURM, phase 2 is submitted with ``depends=phase_1_executor`` +so the scheduler only releases it after every download task succeeds. 
+ +Local (``--slurm 0``) execution runs the two phases sequentially in the +current process — useful for debugging on a workstation. + +Usage on SLURM:: + + uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\ + --output-dir=/scratch/${USER}/robocasa_composite_seen \\ + --hub-repo-id=${HF_USER}/robocasa_composite_seen \\ + --logs-dir=/scratch/${USER}/logs/robocasa \\ + --partition=cpu \\ + --download-cpus=4 --download-mem=8G \\ + --aggregate-cpus=16 --aggregate-mem-per-cpu=4G \\ + --push-to-hub + +Local debug (sequential, single process):: + + uv run python examples/port_datasets/slurm_build_robocasa_composite_seen.py \\ + --output-dir=/tmp/robocasa_composite_seen \\ + --slurm=0 \\ + --tasks=PrepareCoffee,KettleBoiling + +Prereqs: ``robocasa`` + ``robosuite`` installed (see +``docs/source/benchmarks/robocasa.mdx``); ``datatrove`` installed via the +``annotations`` extra (``uv sync --extra annotations``). +""" + +from __future__ import annotations + +import argparse +import logging +from pathlib import Path + +from datatrove.executor import LocalPipelineExecutor +from datatrove.executor.slurm import SlurmPipelineExecutor +from datatrove.pipeline.base import PipelineStep + +# Reuse the per-task helpers + canonical task list from the single-machine +# script so both runners share one source of truth for "what does it mean to +# download a composite_seen task". The helpers are spelled with leading +# underscores there (module-private), but the slurm runner is a legitimate +# in-tree consumer, so we alias them to clean names here. 
class DownloadRoboCasaTask(PipelineStep):
    """Phase 1 step: download every RoboCasa task owned by the calling rank.

    The step holds the full list of task names and, at run time, picks its
    own share with the strided slice ``task_names[rank::world_size]``. When
    ``world_size`` equals the number of task names, that slice contains
    exactly one task per rank; a rank whose slice is empty only logs and
    returns.

    Args:
        task_names: Names of the tasks to download (copied into a new list).
        overwrite: Forwarded unchanged to ``download_task`` for every task.
    """

    def __init__(self, task_names: list[str], *, overwrite: bool = False):
        super().__init__()
        self.overwrite = overwrite
        # Defensive copy so later mutation of the caller's list has no effect.
        self.task_names = list(task_names)

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Download the tasks assigned to ``rank``; ``data`` is not used."""
        from lerobot.utils.utils import init_logging  # noqa: PLC0415

        init_logging()
        require_robocasa()

        # Strided ownership: rank r handles items r, r + world_size, ...
        owned = self.task_names[rank::world_size]
        if owned:
            for name in owned:
                logger.info("rank %d/%d: downloading %s", rank, world_size, name)
                extracted_root = download_task(name, overwrite=self.overwrite)
                logger.info(
                    "rank %d/%d: %s extracted at %s",
                    rank,
                    world_size,
                    name,
                    extracted_root,
                )
        else:
            logger.info("rank %d/%d: no tasks assigned", rank, world_size)
+ """ + + def __init__( + self, + task_names: list[str], + *, + output_repo_id: str, + output_dir: Path, + push_to_hub: bool, + private: bool, + ): + super().__init__() + self.task_names = list(task_names) + self.output_repo_id = output_repo_id + self.output_dir = Path(output_dir) + self.push_to_hub = push_to_hub + self.private = private + + def run(self, data=None, rank: int = 0, world_size: int = 1): + from lerobot.utils.utils import init_logging # noqa: PLC0415 + + init_logging() + + if rank != 0: + logger.info("rank %d: aggregation runs on rank 0 only — skipping", rank) + return + + require_robocasa() + from lerobot.datasets import aggregate_datasets # noqa: PLC0415 + from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415 + + # Resolve each task's local extraction root. After phase 1 these + # all exist on disk under robocasa.macros.DATASET_BASE_DIR; if any + # are missing, fail loudly so the operator knows phase 1 didn't + # cleanly complete for that task. + roots: list[Path] = [] + missing: list[str] = [] + for task in self.task_names: + root = resolve_task_root(task) + if not root.exists(): + missing.append(f"{task} -> {root}") + else: + roots.append(root) + if missing: + raise RuntimeError( + "Phase 1 did not produce extracted directories for: " + + ", ".join(missing) + + " — re-run the download phase before aggregating." + ) + + # ``aggregate_datasets`` uses ``repo_ids`` purely for logging / + # the unified task table when ``roots=`` is supplied; the actual + # data is loaded from each root directly. 
def make_download_executor(
    *,
    task_names: list[str],
    overwrite: bool,
    job_name: str,
    logs_dir: Path,
    workers: int,
    partition: str | None,
    cpus_per_task: int,
    mem: str,
    time: str,
    slurm: bool,
):
    """Build the phase-1 (download) executor.

    The pipeline is a single ``DownloadRoboCasaTask`` step and the executor
    is configured with one datatrove task per entry in ``task_names``. Logs
    go to ``logs_dir / job_name``.

    Args:
        task_names: RoboCasa task names to download.
        overwrite: Passed through to ``DownloadRoboCasaTask``.
        job_name: SLURM job name; also the log sub-directory name.
        logs_dir: Parent directory for executor logs.
        workers: Worker count. Used as-is for SLURM; capped at
            ``len(task_names)`` for the local executor.
        partition: SLURM partition (SLURM executor only).
        cpus_per_task: CPUs per SLURM task (SLURM executor only).
        mem: Value for the ``mem`` sbatch argument (SLURM executor only).
        time: SLURM time limit (SLURM executor only).
        slurm: ``True`` for a ``SlurmPipelineExecutor``, ``False`` for a
            ``LocalPipelineExecutor``.

    Returns:
        The configured executor; it is not started here.
    """
    num_tasks = len(task_names)  # one shard per RoboCasa task
    shared_kwargs = {
        "pipeline": [DownloadRoboCasaTask(task_names, overwrite=overwrite)],
        "logging_dir": str(logs_dir / job_name),
        "tasks": num_tasks,
    }

    if not slurm:
        return LocalPipelineExecutor(
            workers=min(workers, num_tasks),
            **shared_kwargs,
        )

    return SlurmPipelineExecutor(
        job_name=job_name,
        workers=workers,
        time=time,
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem": mem},
        **shared_kwargs,
    )
output_dir: Path, + push_to_hub: bool, + private: bool, + job_name: str, + logs_dir: Path, + partition: str | None, + cpus_per_task: int, + mem_per_cpu: str, + time: str, + slurm: bool, + depends: SlurmPipelineExecutor | None, +): + """Phase-2 executor: single worker, aggregates the extracted shards.""" + pipeline = [ + AggregateRoboCasaShards( + task_names, + output_repo_id=output_repo_id, + output_dir=output_dir, + push_to_hub=push_to_hub, + private=private, + ) + ] + logging_dir = str(logs_dir / job_name) + + if slurm: + return SlurmPipelineExecutor( + pipeline=pipeline, + logging_dir=logging_dir, + job_name=job_name, + tasks=1, + workers=1, + time=time, + partition=partition, + cpus_per_task=cpus_per_task, + sbatch_args={"mem-per-cpu": mem_per_cpu}, + depends=depends, # SLURM job dependency on phase 1 + ) + return LocalPipelineExecutor( + pipeline=pipeline, + logging_dir=logging_dir, + tasks=1, + workers=1, + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Build the merged RoboCasa composite_seen LeRobotDataset on SLURM " + "via datatrove (download in parallel, aggregate sequentially)." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # I/O. + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Local directory for the merged dataset (will be created).", + ) + parser.add_argument( + "--hub-repo-id", + type=str, + default=None, + help=( + "Hub repo_id for the merged dataset (e.g. " + "``yourname/robocasa_composite_seen``). Required for " + "``--push-to-hub``; also becomes the merged dataset's " + "canonical ``repo_id``." 
+ ), + ) + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Push the merged dataset to the Hub after aggregation.", + ) + parser.add_argument( + "--private", + action="store_true", + help="When pushing, create the Hub repo as private.", + ) + parser.add_argument( + "--logs-dir", + type=Path, + default=Path("./logs/robocasa"), + help="Path to datatrove logs directory (used for stdout/stderr and " + "phase coordination).", + ) + parser.add_argument( + "--tasks", + type=str, + default=None, + help="Comma-separated task names overriding the default 16 " + "composite_seen list (debug / smoke-test).", + ) + parser.add_argument( + "--overwrite-download", + action="store_true", + help="Force re-download even if the local extraction looks complete.", + ) + + # SLURM controls. + parser.add_argument( + "--slurm", + type=int, + default=1, + help="Launch over SLURM (``1``) or locally / sequentially (``0``).", + ) + parser.add_argument( + "--partition", + type=str, + default=None, + help="SLURM partition. A CPU partition is sufficient — no GPU needed.", + ) + + # Phase-1 (download) sizing. + parser.add_argument( + "--download-workers", + type=int, + default=16, + help="Number of parallel SLURM workers for the download phase. " + "Default matches the number of composite_seen tasks (16).", + ) + parser.add_argument( + "--download-cpus", + type=int, + default=4, + help="CPUs per download worker (the work is network- and " + "tar-extraction-bound).", + ) + parser.add_argument( + "--download-mem", + type=str, + default="8G", + help="Total memory per download worker.", + ) + parser.add_argument( + "--download-time", + type=str, + default="06:00:00", + help="SLURM wall-clock limit per download worker (HH:MM:SS).", + ) + + # Phase-2 (aggregate) sizing. 
def main() -> int:
    """CLI entry point: build both executors and launch them.

    Steps:
      1. Parse CLI arguments and configure root logging.
      2. Resolve the task list: the comma-separated ``--tasks`` override if
         given, otherwise ``COMPOSITE_SEEN_TASKS``.
      3. Build the download (phase 1) and aggregate (phase 2) executors.
      4. With ``--slurm 1``, run only the aggregate executor, which was
         built with ``depends=download_executor``. Otherwise run download
         and then aggregate in the current process.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If no tasks are selected, or ``--push-to-hub`` is given
            without ``--hub-repo-id``.
    """
    args = parse_args()
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="[%(levelname)s] %(name)s: %(message)s",
    )

    if args.tasks:
        # Drop empty entries so "A,,B" or a trailing comma is harmless.
        tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    else:
        tasks = list(COMPOSITE_SEEN_TASKS)
    if not tasks:
        raise SystemExit("No tasks selected.")

    if args.push_to_hub and not args.hub_repo_id:
        raise SystemExit("--push-to-hub requires --hub-repo-id.")

    # Without a Hub repo id the merged dataset still needs a repo_id.
    output_repo_id = args.hub_repo_id or "local/robocasa_composite_seen"
    slurm = args.slurm == 1

    logger.info(
        "Phase 1 (download) — %d tasks across %d workers (slurm=%s)",
        len(tasks),
        min(args.download_workers, len(tasks)),
        slurm,
    )
    download_executor = make_download_executor(
        task_names=tasks,
        overwrite=args.overwrite_download,
        job_name=args.download_job_name,
        logs_dir=args.logs_dir,
        workers=args.download_workers,
        partition=args.partition,
        cpus_per_task=args.download_cpus,
        mem=args.download_mem,
        time=args.download_time,
        slurm=slurm,
    )

    logger.info(
        "Phase 2 (aggregate) — single worker, output: %s (push_to_hub=%s)",
        output_repo_id,
        args.push_to_hub,
    )
    aggregate_executor = make_aggregate_executor(
        task_names=tasks,
        output_repo_id=output_repo_id,
        output_dir=args.output_dir,
        push_to_hub=args.push_to_hub,
        private=args.private,
        job_name=args.aggregate_job_name,
        logs_dir=args.logs_dir,
        partition=args.partition,
        cpus_per_task=args.aggregate_cpus,
        mem_per_cpu=args.aggregate_mem_per_cpu,
        time=args.aggregate_time,
        slurm=slurm,
        depends=download_executor if slurm else None,
    )

    if slurm:
        # Running the aggregate executor is relied on to submit the download
        # executor as well, through the ``depends`` chain, so phase 2 is only
        # released after phase 1.
        aggregate_executor.run()
        # NOTE(review): on the launching host ``run()`` is expected to return
        # once the jobs are queued, not once they finish, so do not claim the
        # dataset exists yet — confirm against the datatrove version in use.
        logger.info(
            "Submitted SLURM jobs '%s' -> '%s'. The merged dataset will be "
            "written to %s once the aggregate job completes (logs under %s).",
            args.download_job_name,
            args.aggregate_job_name,
            args.output_dir,
            args.logs_dir,
        )
    else:
        # Local sequential: run download to completion, then aggregate.
        download_executor.run()
        aggregate_executor.run()
        logger.info("Done. Merged dataset at %s.", args.output_dir)

    return 0
+lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.package-data] diff --git a/src/lerobot/annotations/steerable_pipeline/executor.py b/src/lerobot/annotations/steerable_pipeline/executor.py index 79a7f1614..d8f473f6b 100644 --- a/src/lerobot/annotations/steerable_pipeline/executor.py +++ b/src/lerobot/annotations/steerable_pipeline/executor.py @@ -129,8 +129,8 @@ class Executor: written = self.writer.write_all(records, staging_dir, root) print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True) - # Persist the tool catalog to meta/info.json so chat-template - # consumers (PR 3 SmolVLA2 / Pi0.5 / dataset visualizer) can read + # Persist the tool catalog to meta/info.json so downstream + # consumers (PI052 / Pi0.5 / dataset visualizer) can read # it via ``LeRobotDatasetMetadata.tools`` (PR 1). Idempotent and # additive: anything the user pre-populated is preserved; we only # ensure the canonical ``say`` schema is present. diff --git a/src/lerobot/configs/recipes/subtask_mem.yaml b/src/lerobot/configs/recipes/subtask_mem.yaml index 3dbed98ff..6903b3585 100644 --- a/src/lerobot/configs/recipes/subtask_mem.yaml +++ b/src/lerobot/configs/recipes/subtask_mem.yaml @@ -20,11 +20,10 @@ # bindings are missing simply don't render for that sample, so a # dataset without interjections still trains the rest of the blend. # -# SmolVLA2 note: the `say` tool call on the interjection-response turn -# is flattened to a `...` text marker by the chat tokenizer -# (`_flatten_say_tool_calls`) before `apply_chat_template`, so the LM -# head learns to emit exactly the marker the runtime parses back -# (`_split_plan_and_say`). 
+# Tool-call note: the `say` tool call on the interjection-response turn +# is flattened to a `...` text marker by the tokenizer step +# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the +# marker the runtime parses back (`_split_plan_and_say`). blend: diff --git a/src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml b/src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml index 366dcaa16..2cd1e7ae5 100644 --- a/src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml +++ b/src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml @@ -20,11 +20,10 @@ # bindings are missing simply don't render for that sample, so a # dataset without interjections still trains the rest of the blend. # -# SmolVLA2 note: the `say` tool call on the interjection-response turn -# is flattened to a `...` text marker by the chat tokenizer -# (`_flatten_say_tool_calls`) before `apply_chat_template`, so the LM -# head learns to emit exactly the marker the runtime parses back -# (`_split_plan_and_say`). +# Tool-call note: the `say` tool call on the interjection-response turn +# is flattened to a `...` text marker by the tokenizer step +# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the +# marker the runtime parses back (`_split_plan_and_say`). blend: diff --git a/src/lerobot/configs/recipes/subtasks_vqa.yaml b/src/lerobot/configs/recipes/subtasks_vqa.yaml index 1002b3b8e..48a6ced54 100644 --- a/src/lerobot/configs/recipes/subtasks_vqa.yaml +++ b/src/lerobot/configs/recipes/subtasks_vqa.yaml @@ -1,5 +1,4 @@ -# subtasks_vqa — Hi-Robot blend, shared between SmolVLA2 (SmolVLM2 -# backbone) and PI052 (PaliGemma backbone). +# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone). # # Trains two things only: subtasks and VQA. Plan and memory are # intentionally left out — keeps the prompt short and the training @@ -10,9 +9,8 @@ # low_level_execution — flow loss with [images, subtask, state]. # ask_vqa_{top,wrist} — camera-grounded VQA. 
# -# Each backbone's text tokenizer renders these messages differently -# (SmolVLA2 uses the chat template; PI052 concatenates as plain -# ``Role: content`` text), but the recipe spec is identical. +# PI052's text tokenizer renders these messages as plain +# ``Role: content`` text (PaliGemma is not chat-pretrained). blend: diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 590d8f098..3de801ecf 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -27,7 +27,6 @@ from .sac.configuration_sac import SACConfig as SACConfig from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig from .sarm.configuration_sarm import SARMConfig as SARMConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig -from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .utils import make_robot_action, prepare_observation_for_inference from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig @@ -52,7 +51,6 @@ __all__ = [ "SACConfig", "SARMConfig", "SmolVLAConfig", - "SmolVLA2Config", "TDMPCConfig", "VQBeTConfig", "WallXConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 777f67f84..6f0447c4e 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -144,10 +144,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy - elif name == "smolvla2": - from .smolvla2.modeling_smolvla2 import SmolVLA2Policy - - return SmolVLA2Policy elif name == "sarm": from .sarm.modeling_sarm import SARMRewardModel @@ -180,8 +176,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. 
Supported types include "tdmpc", - "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", - "smolvla", "reward_classifier", "wall_x". + "multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", + "pi052", "sac", "smolvla", "reward_classifier", "wall_x". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -212,10 +208,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SACConfig(**kwargs) elif policy_type == "smolvla": return SmolVLAConfig(**kwargs) - elif policy_type == "smolvla2": - from .smolvla2.configuration_smolvla2 import SmolVLA2Config - - return SmolVLA2Config(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) elif policy_type == "groot": @@ -424,17 +416,6 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif policy_cfg.type == "smolvla2": - # NOTE: SmolVLA2Config subclasses SmolVLAConfig, so this branch - # MUST come before the SmolVLAConfig isinstance check below - # (otherwise SmolVLA2 would silently pick up SmolVLA's processor). - from .smolvla2.processor_smolvla2 import make_smolvla2_pre_post_processors - - processors = make_smolvla2_pre_post_processors( - config=policy_cfg, - dataset_stats=kwargs.get("dataset_stats"), - ) - elif isinstance(policy_cfg, SmolVLAConfig): from .smolvla.processor_smolvla import make_smolvla_pre_post_processors diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index 4214baba7..5b4c25924 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -32,8 +32,8 @@ This is the dual-head co-training pattern from the paper: with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits inference into a text-prediction step followed by an action-prediction -step, which mirrors what ``SmolVLA2Runtime`` already does on a -SmolVLM2 backbone. 
+step, which the multi-rate ``PI052Runtime`` (in +``lerobot.policies.pi052.inference``) drives at separate rates. """ from dataclasses import dataclass @@ -48,10 +48,9 @@ from ..pi05.configuration_pi05 import PI05Config class PI052Config(PI05Config): """π0.5 with the PaliGemma LM head re-enabled for subtask prediction. - See ``SmolVLA2Config`` for the analogous SmolVLM2-backed dual-head - config. Same recipe-driven training surface; the only difference is - which backbone the policy uses (PaliGemma here vs SmolVLM2 there). - The flow:text loss split is the milder 5:1 (see ``flow_loss_weight``). + Recipe-driven dual-head training: the flow head supervises actions, + the LM head supervises subtask / plan / memory / VQA text. The + flow:text loss split is the milder 5:1 (see ``flow_loss_weight``). """ # Recipe / language stack --------------------------------------------- @@ -64,17 +63,17 @@ class PI052Config(PI05Config): apply_chat_template: bool = False """PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a - chat template. So unlike SmolVLA2 we don't apply one. The recipe - renderer's output is concatenated as a plain prefix + assistant - suffix instead, mirroring how the π0.5 paper's high-level inference - samples text auto-regressively after the prefix.""" + chat template, so we don't apply one. The recipe renderer's output + is concatenated as a plain prefix + assistant suffix instead, + mirroring how the π0.5 paper's high-level inference samples text + auto-regressively after the prefix.""" # Loss weights -------------------------------------------------------- # Paper §IV.D uses α=10 between the flow and text terms, assuming # text is a rare auxiliary task. With the recipe stack the flow-only # `low_level` branch fires on a large share of samples, so α=10 # swamps the LM head and collapses generation into degenerate - # repetition. We use the milder 5:1 split (matches SmolVLA2Config). + # repetition. We use the milder 5:1 split here. 
text_loss_weight: float = 1.0 """Weight on the LM-head cross-entropy term. Set to ``0`` to disable text training entirely (reverts to flow-only / π0.5 behaviour).""" @@ -93,8 +92,8 @@ class PI052Config(PI05Config): hierarchical inference.""" # Per-component prompt dropout (Pi0.7 §V.E) --------------------------- - # Same regulariser surface as SmolVLA2: randomly drop non-target - # context messages so the LM head learns to handle missing / + # Randomly drop non-target context messages so the LM head learns + # to handle missing / # stale plan / memory at inference. Defaults to 0.0 so behaviour # is identical until explicitly enabled. plan_dropout_prob: float = 0.0 diff --git a/src/lerobot/policies/smolvla2/inference/__init__.py b/src/lerobot/policies/pi052/inference/__init__.py similarity index 91% rename from src/lerobot/policies/smolvla2/inference/__init__.py rename to src/lerobot/policies/pi052/inference/__init__.py index 30f77635a..10f7f4726 100644 --- a/src/lerobot/policies/smolvla2/inference/__init__.py +++ b/src/lerobot/policies/pi052/inference/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""SmolVLA2 inference / runtime orchestration. +"""PI052 inference / runtime orchestration. Multi-rate runtime that mirrors the recipe-time training shape: @@ -22,12 +22,12 @@ Multi-rate runtime that mirrors the recipe-time training shape: ask_vqa_* → AskVQAFwd (event: stdin question) speech tool calls → DispatchToolCalls (event: tool_call_pending) -The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and -calls ``run()``. +The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls +``run()``. 
""" from .repl import StdinReader -from .runtime import SmolVLA2Runtime +from .runtime import PI052Runtime from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event from .steps import ( AskVQAFwd, @@ -44,7 +44,7 @@ from .ui import make_state_panel, print_robot_lines, print_user_line __all__ = [ # runtime - "SmolVLA2Runtime", + "PI052Runtime", "StdinReader", # state helpers "initial_runtime_state", diff --git a/src/lerobot/policies/smolvla2/inference/repl.py b/src/lerobot/policies/pi052/inference/repl.py similarity index 95% rename from src/lerobot/policies/smolvla2/inference/repl.py rename to src/lerobot/policies/pi052/inference/repl.py index 671de4971..2b8813f58 100644 --- a/src/lerobot/policies/smolvla2/inference/repl.py +++ b/src/lerobot/policies/pi052/inference/repl.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Stdin REPL event collector for the SmolVLA2 runtime. +"""Stdin REPL event collector for the PI052 runtime. Reads non-blocking stdin lines, classifies each one heuristically: @@ -23,7 +23,7 @@ Reads non-blocking stdin lines, classifies each one heuristically: Plugged into the runtime via ``event_collector=StdinReader().poll``. -Note: the shipped CLI (``lerobot-smolvla2-runtime``) drives stdin +Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin directly in its REPL / autonomous loops and does *not* wire this collector; it's kept as the documented embedding hook and for tests. 
""" @@ -92,7 +92,7 @@ class StdinReader: if not state.get("task"): task = line[5:].strip() if lower.startswith("task:") else line state["task"] = task - print(f"[smolvla2] Task: {task}", flush=True) + print(f"[pi052] Task: {task}", flush=True) self._seen_first_line = True return diff --git a/src/lerobot/policies/smolvla2/inference/runtime.py b/src/lerobot/policies/pi052/inference/runtime.py similarity index 98% rename from src/lerobot/policies/smolvla2/inference/runtime.py rename to src/lerobot/policies/pi052/inference/runtime.py index 6605f72cb..3ca3e0576 100644 --- a/src/lerobot/policies/smolvla2/inference/runtime.py +++ b/src/lerobot/policies/pi052/inference/runtime.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""SmolVLA2 runtime loop. +"""PI052 runtime loop. Threads the multi-rate inference pipeline together with a stdin REPL event collector, drives ticks through :class:`TickClock`, and prints @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) @dataclass -class SmolVLA2Runtime: +class PI052Runtime: """Compose the inference pipeline and drive it tick-by-tick.""" policy: Any @@ -195,11 +195,11 @@ class SmolVLA2Runtime: def _flush_logs(self) -> None: for line in self.state.get("log_lines") or []: - print(f"[smolvla2] {line}", flush=True) + print(f"[pi052] {line}", flush=True) def _on_shutdown(self) -> None: # Drain any queued action chunks safely. 
queue = self.state.get("action_queue") if isinstance(queue, deque): queue.clear() - print("[smolvla2] runtime stopped", flush=True) + print("[pi052] runtime stopped", flush=True) diff --git a/src/lerobot/policies/smolvla2/inference/runtime_state.py b/src/lerobot/policies/pi052/inference/runtime_state.py similarity index 100% rename from src/lerobot/policies/smolvla2/inference/runtime_state.py rename to src/lerobot/policies/pi052/inference/runtime_state.py diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/pi052/inference/steps.py similarity index 87% rename from src/lerobot/policies/smolvla2/inference/steps.py rename to src/lerobot/policies/pi052/inference/steps.py index f5d485b23..d205cc6e7 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/pi052/inference/steps.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference steps for the SmolVLA2 multi-rate runtime. +"""Inference steps for the PI052 multi-rate runtime. Each step is a tiny class with a ``trigger`` and an ``__call__(state)``; the runtime applies them in order each tick. When a step's trigger @@ -98,7 +98,7 @@ class LowLevelForward(InferenceStep): if not state.get("task"): return None - # SmolVLA produces *action chunks* (typically 50 steps via + # PI052 produces *action chunks* (typically 50 steps via # flow-matching). Every step gets dispatched to the robot; # popping one per dispatch tick is essentially free. Only # generate a new chunk once the previous one has fully @@ -256,34 +256,11 @@ def _build_text_batch( ) -> dict[str, Any]: """Tokenize chat messages into the batch ``select_message`` expects. - Dispatches on the policy backbone so one runtime drives both: - - * ``smolvla2`` (SmolVLM2) — chat template via ``apply_chat_template``. 
- * ``pi052`` (PaliGemma) — flat ``Role: content`` text, since - PaliGemma is not chat-pretrained (mirrors ``PI052TextTokenizerStep``). - """ - if getattr(getattr(policy, "config", None), "type", "") == "pi052": - return _build_text_batch_pi052( - policy, prompt_messages, add_generation_prompt=add_generation_prompt - ) - return _build_text_batch_chat( - policy, prompt_messages, add_generation_prompt=add_generation_prompt - ) - - -def _build_text_batch_pi052( - policy: Any, - prompt_messages: list[dict[str, Any]], - *, - add_generation_prompt: bool = True, -) -> dict[str, Any]: - """PI052 text batch — flat ``User: … \\nAssistant: …`` prompt. - - PaliGemma ships no chat template, so PI052 trains on the plain - role-prefixed concatenation built by ``PI052TextTokenizerStep``. - Reuses that exact formatter so the inference prefix matches - training. ``add_generation_prompt`` appends the bare ``Assistant: `` - header the LM head continues from. + PI052's backbone (PaliGemma) ships no chat template, so we train on + a plain role-prefixed concatenation built by + ``PI052TextTokenizerStep``. We reuse that exact formatter so the + inference prefix matches training; ``add_generation_prompt`` appends + the bare ``Assistant: `` header the LM head continues from. 
""" import torch # noqa: PLC0415 from transformers import AutoTokenizer # noqa: PLC0415 @@ -315,99 +292,6 @@ def _build_text_batch_pi052( if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool: attn = attn.bool() - device = getattr(getattr(policy, "config", None), "device", None) - if device is not None: - try: - ids = ids.to(device) - if attn is not None and hasattr(attn, "to"): - attn = attn.to(device) - except Exception as exc: # noqa: BLE001 - logger.debug("could not move pi052 lang tokens to %s: %s", device, exc) - return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer} - - -def _build_text_batch_chat( - policy: Any, - prompt_messages: list[dict[str, Any]], - *, - add_generation_prompt: bool = True, -) -> dict[str, Any]: - """SmolVLA2 (SmolVLM2) text batch — chat-template tokenization. - - Reuses ``_strip_lerobot_blocks`` so the inference prompt shape - matches the training-time chat tokenizer step exactly. - """ - from transformers import AutoTokenizer # noqa: PLC0415 - - tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name) - if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: - tokenizer.pad_token = tokenizer.eos_token - - # Reuse the *exact* normaliser that the training-time chat - # tokenizer step uses (``_strip_lerobot_blocks``). It handles all - # the cases the SmolVLM chat template expects: - # * ``content: list[block]`` → keep text blocks, drop images - # * ``content: None`` → ``[{type: text, text: ""}]`` - # * ``content: str`` / anything else → ``[{type: text, text: str(content)}]`` - # Doing it any other way creates a training/inference mismatch in - # exactly the prompt shape the model was supervised on. Also - # strips ``stream`` / ``target`` recipe metadata. 
- from lerobot.policies.smolvla2.chat_processor_smolvla2 import ( # noqa: PLC0415 - _strip_lerobot_blocks, - ) - - text_messages = [_strip_lerobot_blocks(m) for m in prompt_messages] - encoded = tokenizer.apply_chat_template( - text_messages, - add_generation_prompt=add_generation_prompt, - tokenize=True, - return_tensors="pt", - ) - # ``apply_chat_template`` can return any of: - # - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers), - # - a list[int] / list[list[int]] (when ``return_tensors`` is ignored), - # - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask`` - # (newer transformers, especially via processor.apply_chat_template - # forwarding through here). - # Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's - # attention mask when available so we don't have to re-derive it - # from ``pad_token_id`` (which can be ``None`` for SmolVLM). - attn: Any = None - if hasattr(encoded, "input_ids"): - ids = encoded.input_ids - attn = getattr(encoded, "attention_mask", None) - elif isinstance(encoded, dict) and "input_ids" in encoded: - ids = encoded["input_ids"] - attn = encoded.get("attention_mask") - else: - ids = encoded - if isinstance(ids, list): - if ids and isinstance(ids[0], list): - ids = ids[0] - import torch # noqa: PLC0415 - - ids = torch.tensor(ids, dtype=torch.long) - if hasattr(ids, "ndim") and ids.ndim == 1: - ids = ids.unsqueeze(0) - if attn is None and tokenizer.pad_token_id is not None: - attn = ids != tokenizer.pad_token_id - elif isinstance(attn, list): - import torch # noqa: PLC0415 - - attn = torch.tensor(attn, dtype=torch.long) - if attn.ndim == 1: - attn = attn.unsqueeze(0) - # SmolVLA's ``eager_attention_forward`` does - # ``torch.where(attention_mask[..., None, :, :], ...)`` which - # requires a *bool* condition tensor; ``BatchEncoding``'s - # attention_mask is typically Long (0/1). 
Cast so the prefix - # forward doesn't blow up with ``where expected condition to be a - # boolean tensor, but got a tensor with dtype Long``. - if attn is not None and hasattr(attn, "dtype"): - import torch as _torch # noqa: PLC0415 - - if attn.dtype != _torch.bool: - attn = attn.bool() # Move tokens onto the policy's device — otherwise prefix embedding # raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA # model), which the caller's broad except would swallow silently. @@ -418,7 +302,7 @@ def _build_text_batch_chat( if attn is not None and hasattr(attn, "to"): attn = attn.to(device) except Exception as exc: # noqa: BLE001 - logger.debug("could not move lang tokens to %s: %s", device, exc) + logger.debug("could not move pi052 lang tokens to %s: %s", device, exc) return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer} @@ -1022,13 +906,7 @@ def _generate_with_policy( "temperature": temperature, "top_p": top_p, } - # Only pass ``suppress_loc_tokens`` to backbones that accept it - # (pi052). SmolVLA2's ``select_message`` does not, so we omit - # the kwarg there to avoid breaking the shared runtime. 
- import inspect # noqa: PLC0415 - - if "suppress_loc_tokens" in inspect.signature(policy.select_message).parameters: - kwargs["suppress_loc_tokens"] = suppress_loc_tokens + kwargs["suppress_loc_tokens"] = suppress_loc_tokens return policy.select_message(batch, **kwargs) except Exception as exc: # noqa: BLE001 logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG)) diff --git a/src/lerobot/policies/smolvla2/inference/triggers.py b/src/lerobot/policies/pi052/inference/triggers.py similarity index 98% rename from src/lerobot/policies/smolvla2/inference/triggers.py rename to src/lerobot/policies/pi052/inference/triggers.py index b4646261b..dd9537f66 100644 --- a/src/lerobot/policies/smolvla2/inference/triggers.py +++ b/src/lerobot/policies/pi052/inference/triggers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Trigger primitives for SmolVLA2's multi-rate inference runtime. +"""Trigger primitives for PI052's multi-rate inference runtime. Mirrors the plan's Section "Runtime orchestration": each ``InferenceStep`` is gated by a :class:`Trigger` that decides per tick diff --git a/src/lerobot/policies/smolvla2/inference/ui.py b/src/lerobot/policies/pi052/inference/ui.py similarity index 97% rename from src/lerobot/policies/smolvla2/inference/ui.py rename to src/lerobot/policies/pi052/inference/ui.py index 0d98c5897..a2d90f076 100644 --- a/src/lerobot/policies/smolvla2/inference/ui.py +++ b/src/lerobot/policies/pi052/inference/ui.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Rich-based REPL layout for the SmolVLA2 runtime. +"""Rich-based REPL layout for the PI052 runtime. 
Two-zone terminal layout: @@ -97,7 +97,7 @@ def make_state_panel(state: dict[str, Any]) -> Any: ) return Panel( table, - title=f"[bold]SmolVLA2 state[/] · mode: {mode_tag}", + title=f"[bold]PI052 state[/] · mode: {mode_tag}", border_style="cyan", ) diff --git a/src/lerobot/policies/smolvla2/inference/vqa.py b/src/lerobot/policies/pi052/inference/vqa.py similarity index 99% rename from src/lerobot/policies/smolvla2/inference/vqa.py rename to src/lerobot/policies/pi052/inference/vqa.py index 0992a0a12..6561a25e2 100644 --- a/src/lerobot/policies/smolvla2/inference/vqa.py +++ b/src/lerobot/policies/pi052/inference/vqa.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Interactive VQA for the SmolVLA2 runtime. +"""Interactive VQA for the PI052 runtime. In ``/vlm`` mode a typed line is treated as a VQA question. This module runs the full interactive flow: @@ -379,7 +379,7 @@ def handle_vqa_query( # Feed the FULL observation (every camera + state) to the VLM. The # ``ask_vqa_*`` recipes look single-camera, but the image *block* is # stripped before tokenization — the actual frames reach the model - # via SmolVLA's ``OBS_IMAGES_*`` channels, and ``embed_prefix`` + # via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix`` # consumes *all* ``config.image_features`` regardless of which # camera the sub-recipe was tagged for. So the model always sees # every camera; the operator never has to name one to ask. diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 5a055a5d5..1f431b26c 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -29,10 +29,10 @@ A thin subclass of :class:`PI05Policy` that: with α controllable via ``config.flow_loss_weight``. 
-The same multi-rate runtime that drives ``SmolVLA2Runtime`` (see -``lerobot.policies.smolvla2.inference``) can drive this policy too — -both expose ``predict_action_chunk`` for the action expert and -``select_message`` for the LM head. +The multi-rate inference runtime in ``lerobot.policies.pi052.inference`` +(driven by the ``lerobot-pi052-runtime`` CLI) sits on top of this: +``predict_action_chunk`` for the action expert and ``select_message`` +for the LM head. """ from __future__ import annotations @@ -380,9 +380,9 @@ class PI052Policy(PI05Policy): for p in backbone.lm_head.parameters(): p.requires_grad_(True) # The text model's final norm and last transformer block — - # mirror SmolVLA2's logic, which finds these dynamically by - # the trainable=False parameters that point at the head's - # neighbourhood. + # find them dynamically by walking up from the LM head so we + # don't hard-code module names that may drift across transformers + # versions. text_model = getattr(backbone, "model", None) text_model = getattr(text_model, "language_model", text_model) if text_model is None: @@ -934,9 +934,8 @@ class PI052Policy(PI05Policy): ) -> str: """Generate text continuation from a multimodal prefix. - Mirrors ``SmolVLA2Policy.select_message`` so the same - :class:`lerobot.policies.smolvla2.inference.SmolVLA2Runtime` - can drive π0.5 v2 unchanged. + Consumed by :class:`lerobot.policies.pi052.inference.PI052Runtime` + for the high-level / VQA / memory-update text streams. ``suppress_loc_tokens`` masks PaliGemma's reserved ```` ids ([256000, 257024)) to ``-inf`` before sampling. PaliGemma's diff --git a/src/lerobot/policies/pi052/processor_pi052.py b/src/lerobot/policies/pi052/processor_pi052.py index f7ec21d0a..e95571054 100644 --- a/src/lerobot/policies/pi052/processor_pi052.py +++ b/src/lerobot/policies/pi052/processor_pi052.py @@ -182,8 +182,7 @@ def _load_recipe(path_str: str) -> TrainingRecipe: """Resolve ``path_str`` to a ``TrainingRecipe``. 
Accepts an absolute path or a path relative to - ``src/lerobot/configs/`` (same lookup rules as - ``make_smolvla2_pre_post_processors``). + ``src/lerobot/configs/``. """ p = Path(path_str) if not p.is_absolute() and not p.exists(): diff --git a/src/lerobot/policies/pi052/text_processor_pi052.py b/src/lerobot/policies/pi052/text_processor_pi052.py index 18c926a0b..bf0d7739c 100644 --- a/src/lerobot/policies/pi052/text_processor_pi052.py +++ b/src/lerobot/policies/pi052/text_processor_pi052.py @@ -14,7 +14,7 @@ """π0.5 v2 text-tokenisation step. -PaliGemma is *not* chat-pretrained, so unlike SmolVLA2 we can't lean on +PaliGemma is *not* chat-pretrained, so we can't lean on ``tokenizer.apply_chat_template``. Instead we concatenate the rendered messages as plain text with simple ``User: ... Assistant: ...`` role delimiters — matching the prompt format π0.5 uses in the paper @@ -30,8 +30,7 @@ Outputs: ``target_message_indices``. ``modeling_pi052`` runs cross-entropy on those positions via the PaliGemma ``lm_head``. * ``predict_actions`` — bool tensor, ``True`` iff any of the rendered - target messages has ``message_streams[i] == "low_level"``. Same - semantics as the SmolVLA2 step. + target messages has ``message_streams[i] == "low_level"``. """ from __future__ import annotations @@ -54,11 +53,9 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# Debug helper — see ``chat_processor_smolvla2._dump_recipe_sample`` for the -# matching SmolVLA2 implementation. Behaviour: when -# ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N samples processed (on -# rank 0) are pretty-printed with ``[TGT]...[/TGT]`` markers over the spans -# the LM head will be supervised on. +# Debug helper — when ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N +# samples processed (on rank 0) are pretty-printed with ``[TGT]...[/TGT]`` +# markers over the spans the LM head will be supervised on. 
# --------------------------------------------------------------------------- _DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0")) @@ -246,8 +243,8 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]: # values exceeding the camera's pixel dimensions — they're not pixels.) # Converting to ```` is therefore camera-resolution-independent: # ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here — -# not in the dataset — so the dataset stays backbone-agnostic (SmolVLA2 -# keeps the JSON). +# not in the dataset — so the dataset keeps the raw JSON and stays +# backbone-agnostic. # --------------------------------------------------------------------------- # The 0–1000 scale Qwen2.5-VL emits for grounding coordinates. @@ -424,8 +421,8 @@ def _format_messages( class PI052TextTokenizerStep(ProcessorStep): """Render messages → token ids + label mask + predict_actions flag. - π0.5 analogue of ``SmolVLA2ChatTokenizerStep``. No chat template; - concatenates messages as ``User: ... \\nAssistant: ...`` text. + No chat template; concatenates messages as + ``User: ... \\nAssistant: ...`` text. """ tokenizer_name: str = "google/paligemma-3b-pt-224" @@ -613,8 +610,7 @@ class PI052TextTokenizerStep(ProcessorStep): continue labels[token_pos] = input_ids[token_pos] - # Scan ALL message streams (not just targets) — see - # ``chat_processor_smolvla2.py`` for rationale: the v2 + # Scan ALL message streams (not just targets): the # ``low_level_execution`` recipe drops ``target: true`` on # the assistant to avoid trivial copy-from-user text-CE; the # flow loss still needs to fire, gated by ``stream: low_level``. 
@@ -686,7 +682,7 @@ class PI052TextTokenizerStep(ProcessorStep): def _classify_for_dropout(message: dict[str, Any]) -> str | None: - """Heuristic content-prefix classifier — mirrors SmolVLA2's.""" + """Heuristic content-prefix classifier (plan / memory / subtask).""" content = message.get("content") if isinstance(content, list): text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"] diff --git a/src/lerobot/policies/smolvla2/__init__.py b/src/lerobot/policies/smolvla2/__init__.py deleted file mode 100644 index 4e72870a8..000000000 --- a/src/lerobot/policies/smolvla2/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SmolVLA2 — SmolVLA with the SmolVLM language head re-enabled. - -SmolVLA strips the LM head from the SmolVLM backbone because it only does -flow-matching action prediction. SmolVLA2 keeps the LM head so the same -model can train on the full Hi Robot / MEM / ECoT message blend defined in -the steerable annotation plan (PR1 + PR2): - -* action-only sub-recipes (e.g. ``low_level_execution``) → flow loss -* text-only sub-recipes (e.g. 
``memory_update``, ``ask_vqa``, - ``user_interjection_response``, ``high_level_subtask``) → CE loss on - ``lm_head`` over the recipe's target message tokens -* mixed sub-recipes → both losses summed (weighted) - -The ``predict_actions`` toggle follows the Pi0.5 convention from Section -I.7 of the plan: ``True`` if any ``low_level`` target is present in the -sample, else ``False``. - -This package is a thin subclass of ``lerobot.policies.smolvla`` so most of -the model code stays in one place — only the dual-loss path and the -chat-template processor live here. -""" - -from .configuration_smolvla2 import SmolVLA2Config - -__all__ = ["SmolVLA2Config"] diff --git a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py deleted file mode 100644 index b628d54c5..000000000 --- a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py +++ /dev/null @@ -1,649 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SmolVLA2's chat-template tokenization step. - -Replaces SmolVLA's plain ``TokenizerProcessorStep`` for SmolVLA2 when a -``recipe_path`` is set. Reads the rendered messages produced by -``RenderMessagesStep`` (PR 1) and produces: - -* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` — - the chat-templated prompt tokenized by SmolVLM's tokenizer, with - ``tools=meta.tools`` (PR 1's catalog). 
-* ``text_labels`` — same shape as token ids, ``-100`` everywhere except - the positions belonging to messages whose index is in - ``target_message_indices``. The next commit's modeling forward path - applies cross-entropy on those positions via the SmolVLM ``lm_head``. -* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered - target messages has ``message_streams[i] == "low_level"``. The - modeling forward uses this to gate the flow head. - -Image / video content blocks in the rendered messages are dropped -before tokenization — the chat template only handles text, and SmolVLA -already passes camera tensors out-of-band via the standard -``OBS_IMAGES_*`` features. This keeps the prefix layout unchanged -(``embed_prefix`` puts image embeddings before language embeddings, -matching the chat-template-stripped text order). -""" - -from __future__ import annotations - -import logging -import os -from dataclasses import dataclass -from typing import Any - -import torch - -from lerobot.configs import PipelineFeatureType, PolicyFeature -from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry -from lerobot.types import EnvTransition, TransitionKey -from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Debug helper: dump the first N rendered samples to stdout so you can sanity- -# check what the model actually sees before kicking off a long training run. -# -# LEROBOT_DUMP_RECIPE_SAMPLES=5 lerobot-train ... -# -# Prints the recipe-rendered messages, the chat-templated text (decoded back -# from token ids), and inline ``[TGT]...[/TGT]`` markers showing which spans -# are supervised by text-CE. Stops after N total dumps to keep training logs -# tractable. Rank-0 only when accelerate sets ``RANK``. 
-# --------------------------------------------------------------------------- - -_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0")) -_DUMPED_SO_FAR = 0 - - -def _is_dump_rank() -> bool: - rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0" - try: - return int(rank) == 0 - except ValueError: - return True - - -def _dump_recipe_sample( - *, - messages: list[dict[str, Any]], - token_ids: list[int], - labels: list[int], - predict_actions: bool, - tokenizer: Any, -) -> None: - """Pretty-print one rendered sample. Stops once the global budget is hit.""" - global _DUMPED_SO_FAR - if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank(): - return - _DUMPED_SO_FAR += 1 - - decoded = tokenizer.decode(token_ids, skip_special_tokens=False) - parts: list[str] = [] - i = 0 - while i < len(labels): - if labels[i] == -100: - j = i - while j < len(labels) and labels[j] == -100: - j += 1 - parts.append(tokenizer.decode(token_ids[i:j], skip_special_tokens=False)) - i = j - else: - j = i - while j < len(labels) and labels[j] != -100: - j += 1 - tgt_text = tokenizer.decode(token_ids[i:j], skip_special_tokens=False) - parts.append(f"[TGT]{tgt_text}[/TGT]") - i = j - annotated = "".join(parts) - - n_tgt = sum(1 for l in labels if l != -100) - print( - "\n========== RECIPE SAMPLE DUMP " - f"({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========", - flush=True, - ) - print(f" predict_actions: {predict_actions}", flush=True) - print(f" rendered messages ({len(messages)}):", flush=True) - for m in messages: - stream = m.get("stream") - target = m.get("target") - role = m.get("role") - content = m.get("content") - print(f" - role={role} stream={stream} target={target}", flush=True) - print(f" content: {content!r}", flush=True) - print(f" token count: {len(token_ids)} (target tokens: {n_tgt})", flush=True) - print(f" decoded (raw):\n {decoded}", flush=True) - print(f" decoded (with target markers):\n {annotated}", flush=True) - 
print("==============================================\n", flush=True) - - -@dataclass -@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer") -class SmolVLA2ChatTokenizerStep(ProcessorStep): - """Render messages → token ids + label mask + predict_actions flag. - - This is the bridge between the recipe stack (PR 1's - ``RenderMessagesStep`` outputs) and the SmolVLA2 modeling forward - (next commit, which reads ``text_labels`` / ``predict_actions``). - Pure-text turns and multi-stream targets are both handled. - """ - - tokenizer_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" - max_length: int = 2048 - padding: str = "longest" - padding_side: str = "right" - tools: list[dict[str, Any]] | None = None - # --- Per-component prompt dropout (Pi0.7 §V.E, plan follow-up - # ``feat/pi05-prompt-dropout``). At training, drop non-target - # messages whose content was substituted from the named recipe - # binding with the given probability. Forces the model to handle - # missing context at inference — directly attacks the memorisation - # collapse where ``current_subtask=""`` puts the prompt OOD. All - # default to 0.0 (no dropout) so behaviour is identical until - # explicitly opted in via the training config. - plan_dropout_prob: float = 0.0 - memory_dropout_prob: float = 0.0 - subtask_dropout_prob: float = 0.0 - interjection_dropout_prob: float = 0.0 - # Optional seed for the per-sample RNG. ``None`` ⇒ use - # ``sample_idx`` derived from the transition (when present), so - # dropout is reproducible across runs but varies per sample. - dropout_seed: int | None = None - - def __post_init__(self) -> None: - # Lazy: don't load the tokenizer until the step actually runs, - # so unit tests that import the module without transformers - # installed still pass. - self._tokenizer: Any = None - if self.tools is None: - # Default: no tools rendered into the system prompt. 
The - # ``say()`` tool was only used by the now-removed - # ``user_interjection_response`` recipe; including its - # schema on every sample adds a long system message to - # the action expert's prefix and creates a train/inference - # mismatch (the inference low-level loop doesn't pass - # tools=, so the chat template doesn't render them). - # Users who actually need tools can set them via - # ``with_tools(meta.tools)``. - self.tools = [] - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - def with_tools(self, tools: list[dict[str, Any]]) -> "SmolVLA2ChatTokenizerStep": - """Override the tools catalog rendered into the system prompt.""" - self.tools = list(tools) - return self - - def __call__(self, transition: EnvTransition) -> EnvTransition | None: - comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} - messages = comp.get("messages") - if not messages: - # No recipe rendering happened — nothing to do; downstream - # falls back to whatever ``task`` is in the transition. - return transition - - tokenizer = self._get_tokenizer() - - # Pull a sample_idx for the dropout RNG. ``index`` is the - # canonical per-frame key on ``LeRobotDataset`` samples and - # flows through into ``COMPLEMENTARY_DATA`` unchanged. When - # absent (e.g. inference) we fall back to 0 which is harmless - # because the dropout probs are also 0 at inference time. 
- if _is_batched_messages(messages): - indices_iter = _sample_indices(comp.get("index"), len(messages)) - encoded = [ - self._encode_messages( - tokenizer, - msg, - list(streams), - sorted(int(i) for i in tgt_indices), - sample_idx=int(s_idx) if s_idx is not None else None, - ) - for msg, streams, tgt_indices, s_idx in zip( - messages, - comp.get("message_streams") or [[] for _ in messages], - comp.get("target_message_indices") or [[] for _ in messages], - indices_iter, - strict=False, - ) - ] - else: - sample_idx = _sample_indices(comp.get("index"), 1)[0] - encoded = [ - self._encode_messages( - tokenizer, - messages, - list(comp.get("message_streams") or []), - sorted(int(i) for i in (comp.get("target_message_indices") or [])), - sample_idx=sample_idx, - ) - ] - - # Optional first-N-samples debug dump for sanity-checking what the - # model actually sees. No-op unless ``LEROBOT_DUMP_RECIPE_SAMPLES`` - # is set; stops globally after the budget is exhausted. - if _DUMP_BUDGET > 0: - # Stream / target metadata live in parallel arrays in - # COMPLEMENTARY_DATA, not on the message dicts themselves - # (the recipe renderer keeps them separate so the chat - # template doesn't choke on unknown keys). Zip them back - # together for the dumper so each printed message shows - # its actual stream + target flag. 
- if _is_batched_messages(messages): - msgs_iter = messages - streams_iter = comp.get("message_streams") or [[] for _ in messages] - targets_iter = comp.get("target_message_indices") or [[] for _ in messages] - else: - msgs_iter = [messages] - streams_iter = [list(comp.get("message_streams") or [])] - targets_iter = [list(comp.get("target_message_indices") or [])] - for msg, streams, targets, (ids, labels, predict_action) in zip( - msgs_iter, streams_iter, targets_iter, encoded, strict=False - ): - target_set = {int(i) for i in targets} - annotated_msgs = [] - for i, m in enumerate(msg): - annotated_msgs.append( - { - **m, - "stream": streams[i] if i < len(streams) else None, - "target": True if i in target_set else None, - } - ) - _dump_recipe_sample( - messages=annotated_msgs, - token_ids=ids, - labels=labels, - predict_actions=predict_action, - tokenizer=tokenizer, - ) - - pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 - target_length = self.max_length if self.padding == "max_length" else max( - len(ids) for ids, _, _ in encoded - ) - target_length = min(target_length, self.max_length) - - ids_batch = [] - attn_batch = [] - labels_batch = [] - predict_actions = [] - for ids, labels, predict_action in encoded: - ids = ids[:target_length] - labels = labels[:target_length] - attn = [1] * len(ids) - if len(ids) < target_length: - n_pad = target_length - len(ids) - ids = ids + [pad_id] * n_pad - labels = labels + [-100] * n_pad - attn = attn + [0] * n_pad - ids_batch.append(ids) - attn_batch.append(attn) - labels_batch.append(labels) - predict_actions.append(predict_action) - - ids_t = torch.tensor(ids_batch, dtype=torch.long) - attn_t = torch.tensor(attn_batch, dtype=torch.bool) - labels_t = torch.tensor(labels_batch, dtype=torch.long) - predict_actions_t = torch.tensor(predict_actions, dtype=torch.bool) - - if not _is_batched_messages(messages): - ids_t = ids_t.squeeze(0) - attn_t = attn_t.squeeze(0) - labels_t = labels_t.squeeze(0) - 
predict_actions_t = predict_actions_t.squeeze(0) - - new_complementary = dict(comp) - # Drop the per-recipe sidecar keys; everything downstream needs - # is now in the tokenized form. - new_complementary.pop("messages", None) - new_complementary.pop("message_streams", None) - new_complementary.pop("target_message_indices", None) - # SmolVLA's pipeline expects ``OBS_LANGUAGE_TOKENS`` / - # ``OBS_LANGUAGE_ATTENTION_MASK`` on the OBSERVATION key. Place - # them there — and drop ``task`` so the upstream - # ``TokenizerProcessorStep`` (which we replace) doesn't double- - # tokenize. - observation = dict(transition.get(TransitionKey.OBSERVATION) or {}) - observation[OBS_LANGUAGE_TOKENS] = ids_t - observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t - new_complementary["text_labels"] = labels_t - new_complementary["predict_actions"] = predict_actions_t - new_complementary.pop("task", None) - - new_transition = dict(transition) - new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary - new_transition[TransitionKey.OBSERVATION] = observation - return new_transition - - def _encode_messages( - self, - tokenizer: Any, - messages: list[dict[str, Any]], - message_streams: list[str | None], - target_indices: list[int], - sample_idx: int | None = None, - ) -> tuple[list[int], list[int], bool]: - # Apply per-component prompt dropout *before* tokenisation, so - # the dropped messages don't contribute tokens or label-mask - # positions at all. Re-maps ``target_indices`` to account for - # removed messages. - messages, target_indices = self._apply_prompt_dropout( - messages, target_indices, sample_idx - ) - # Flatten ``tool_calls`` into a textual ``...`` marker - # *before* the chat template sees them, so the model is trained - # to emit the same marker the inference parser - # (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``. 
- messages = [_flatten_say_tool_calls(m) for m in messages] - text_messages = [_strip_lerobot_blocks(m) for m in messages] - - full_ids = tokenizer.apply_chat_template( - text_messages, - tools=self.tools, - add_generation_prompt=False, - tokenize=True, - return_tensors=None, - ) - full_ids = _as_token_ids(full_ids) - - labels = [-100] * len(full_ids) - for tgt in target_indices: - prefix_ids = tokenizer.apply_chat_template( - text_messages[:tgt], - tools=self.tools, - add_generation_prompt=False, - tokenize=True, - return_tensors=None, - ) - full_through_target = tokenizer.apply_chat_template( - text_messages[: tgt + 1], - tools=self.tools, - add_generation_prompt=False, - tokenize=True, - return_tensors=None, - ) - prefix_ids = _as_token_ids(prefix_ids) - full_through_target = _as_token_ids(full_through_target) - start = len(prefix_ids) - end = min(len(full_through_target), len(full_ids)) - for pos in range(start, end): - labels[pos] = int(full_ids[pos]) - - # ``predict_actions`` is True iff this sample's recipe declares - # at least one ``low_level`` message — regardless of whether - # it's a target. The ``low_level_execution`` recipe in v2 uses - # ``stream: low_level`` on both user and assistant turns but - # only renders the *user* subtask (no text-CE target on the - # assistant) to avoid trivial "copy previous turn" supervision. - # Scanning targets alone would miss this sample's action loss. - predict_actions = any(s == "low_level" for s in message_streams) - return [int(i) for i in full_ids], labels, predict_actions - - def _apply_prompt_dropout( - self, - messages: list[dict[str, Any]], - target_indices: list[int], - sample_idx: int | None, - ) -> tuple[list[dict[str, Any]], list[int]]: - """Probabilistically drop non-target context messages. - - Heuristic content sniffing — matches the prefix strings that - ``subtask_mem_vqa_speech.yaml``'s recipes use when injecting plan / - memory / subtask / interjection content. Anything else is - kept unchanged. 
Target messages are never dropped (we still - need their tokens for supervision). - - Returns ``(new_messages, new_target_indices)`` where the - indices are re-mapped to point at the same target turns in - the trimmed list. - """ - probs = { - "plan": float(self.plan_dropout_prob or 0.0), - "memory": float(self.memory_dropout_prob or 0.0), - "subtask": float(self.subtask_dropout_prob or 0.0), - "interjection": float(self.interjection_dropout_prob or 0.0), - } - if not any(p > 0.0 for p in probs.values()): - return messages, target_indices - - # Deterministic per-sample RNG so dropout is reproducible - # across runs (matters for debugging / repro) but varies - # frame-to-frame. - import random # noqa: PLC0415 - - seed_int = self.dropout_seed if self.dropout_seed is not None else (sample_idx or 0) - rng = random.Random(int(seed_int) & 0xFFFFFFFF) - - target_set = set(target_indices) - keep_flags: list[bool] = [] - for i, msg in enumerate(messages): - if i in target_set: - keep_flags.append(True) - continue - kind = _classify_message_for_dropout(msg) - if kind and rng.random() < probs.get(kind, 0.0): - keep_flags.append(False) - else: - keep_flags.append(True) - - new_messages = [m for m, keep in zip(messages, keep_flags) if keep] - # Re-map target_indices: each old index drops by the count of - # falsy flags before it. 
- new_target_indices: list[int] = [] - for old_idx in target_indices: - dropped_before = sum(1 for k in keep_flags[:old_idx] if not k) - new_target_indices.append(old_idx - dropped_before) - return new_messages, sorted(new_target_indices) - - def transform_features( - self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] - ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - """Pass-through; this step writes runtime tensors not features.""" - return features - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _get_tokenizer(self): # noqa: ANN202 - if self._tokenizer is not None: - return self._tokenizer - try: - from transformers import AutoTokenizer # noqa: PLC0415 - except ImportError as exc: # pragma: no cover - raise ImportError( - "SmolVLA2ChatTokenizerStep requires transformers. " - "`pip install lerobot[transformers-dep]`." - ) from exc - self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) - if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None: - self._tokenizer.pad_token = self._tokenizer.eos_token - return self._tokenizer - - -def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]: - """Remove LeRobot-specific multimodal blocks from ``message`` content. - - The recipe DSL allows authors to write multimodal content like - ``{"type": "image", "feature": "observation.images.top"}``. SmolVLM's - tokenizer doesn't know that ``feature`` key (it expects ``url`` or - ``path``). The actual image tensor flows through SmolVLA's - ``OBS_IMAGES_*`` channels separately; the chat template only needs - the text. So we strip non-text blocks before tokenizing. 
- """ - new = dict(message) - content = new.get("content") - if isinstance(content, list): - text_parts: list[dict[str, Any]] = [] - for block in content: - if not isinstance(block, dict): - continue - if block.get("type") == "text": - text_parts.append({"type": "text", "text": str(block.get("text", ""))}) - new["content"] = text_parts or [{"type": "text", "text": ""}] - elif content is None: - new["content"] = [{"type": "text", "text": ""}] - else: - new["content"] = [{"type": "text", "text": str(content)}] - if "tool_calls" in new and not new["tool_calls"]: - # Drop empty tool_calls — some templates render them as a - # spurious empty marker. - new.pop("tool_calls") - # ``stream`` and ``target`` were recipe metadata; templates don't - # know them and may warn or crash. - new.pop("stream", None) - new.pop("target", None) - return new - - -def _content_to_text(content: Any) -> str: - """Collapse a message's ``content`` (string or multimodal blocks) to plain text.""" - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - t = block.get("text") - if isinstance(t, str): - parts.append(t) - return "\n".join(parts) - return "" - - -def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]: - """Serialize assistant ``say`` tool calls into a textual ``...`` - marker inside the message content (Pi 0.5-style flat tool-call - serialization). - - SmolVLM's chat template would otherwise render ``tool_calls`` as a - structured JSON ```` block, so the LM head learns to emit - JSON — but the inference parser ``_split_plan_and_say`` looks for a - ``...`` marker (``_SAY_RE``). Rewriting the call into the - content text *before* ``apply_chat_template`` aligns the two: the - template only ever tokenizes plain text, and the supervised target - span trains the model to produce the exact marker the runtime reads. 
- - Messages without ``say`` tool calls are returned unchanged. - """ - tool_calls = message.get("tool_calls") - if not tool_calls: - return message - - say_texts: list[str] = [] - for call in tool_calls: - if not isinstance(call, dict): - continue - fn = call.get("function") or {} - if fn.get("name") != "say": - continue - args = fn.get("arguments") - if isinstance(args, str): - try: - import json # noqa: PLC0415 - - args = json.loads(args) - except (ValueError, TypeError): - args = {} - text = args.get("text", "") if isinstance(args, dict) else "" - if text: - say_texts.append(str(text)) - - if not say_texts: - # No ``say`` calls (or empty text) — drop the structured calls so - # the template doesn't render a stray JSON block, but leave the - # content alone. - new = dict(message) - new.pop("tool_calls", None) - return new - - new = dict(message) - base = _content_to_text(new.get("content")).strip() - marker = "".join(f"{t}" for t in say_texts) - new["content"] = f"{base}\n{marker}" if base else marker - new.pop("tool_calls", None) - return new - - -def _is_batched_messages(messages: Any) -> bool: - return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) - - -def _sample_indices(value: Any, batch_size: int) -> list[int | None]: - if value is None: - return [None] * batch_size - if isinstance(value, torch.Tensor): - if value.numel() == 1: - return [int(value.item())] * batch_size - values = value.reshape(-1).tolist() - return [int(v) for v in values[:batch_size]] - if isinstance(value, (list, tuple)): - if len(value) == 1: - return _sample_indices(value[0], batch_size) - return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]] - return [int(value)] * batch_size - - -def _classify_message_for_dropout(message: dict[str, Any]) -> str | None: - """Best-effort classification of which recipe binding contributed - to this message, used for per-component dropout. 
- - The canonical recipe authors plan/memory/subtask injections with - distinctive prefix strings in the rendered content. Matching on - those prefixes is brittle if a future recipe author uses - different wording — but it's also localised to one place and - only affects the dropout fraction (never the actual semantics). - Returns ``None`` for messages we don't recognise; those are - always kept. - """ - content = message.get("content") - if isinstance(content, list): - text_parts: list[str] = [] - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - t = block.get("text") - if isinstance(t, str): - text_parts.append(t) - content = "\n".join(text_parts) - if not isinstance(content, str): - return None - head = content.lstrip().lower() - if head.startswith("plan:") or head.startswith("previous plan"): - return "plan" - if head.startswith("memory:") or head.startswith("previous memory"): - return "memory" - if head.startswith("current subtask") or head.startswith("completed subtask"): - return "subtask" - return None - - -def _as_token_ids(value: Any) -> list[int]: - if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()): - value = value["input_ids"] - if hasattr(value, "tolist"): - value = value.tolist() - if isinstance(value, list) and value and isinstance(value[0], list): - value = value[0] - return [int(i) for i in value] - - -# Re-export for tests / introspection -strip_lerobot_blocks = _strip_lerobot_blocks -flatten_say_tool_calls = _flatten_say_tool_calls diff --git a/src/lerobot/policies/smolvla2/configuration_smolvla2.py b/src/lerobot/policies/smolvla2/configuration_smolvla2.py deleted file mode 100644 index 9238dd07e..000000000 --- a/src/lerobot/policies/smolvla2/configuration_smolvla2.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass - -from lerobot.configs import PreTrainedConfig - -from ..smolvla.configuration_smolvla import SmolVLAConfig - - -@PreTrainedConfig.register_subclass("smolvla2") -@dataclass -class SmolVLA2Config(SmolVLAConfig): - """SmolVLA2 — SmolVLA with the underlying SmolVLM language head re-enabled. - - SmolVLA strips the LM head from the SmolVLM backbone because it only - needs flow-matching action prediction. SmolVLA2 keeps the LM head so the - same model can train on: - - * **action-only sub-recipes** (e.g. ``low_level_execution``) — flow loss - on the action expert, same as SmolVLA. ``predict_actions=True``. - * **text-only sub-recipes** (e.g. ``memory_update`` / ``ask_vqa`` / - ``user_interjection_response`` / ``high_level_subtask``) — cross- - entropy loss on the LM head over the recipe's target message tokens. - Skips the flow head entirely. ``predict_actions=False``. - * **mixed sub-recipes** — both heads run, losses summed (weighted). - - The split is controlled by ``predict_actions = bool(targets_by_stream - .get("low_level"))`` per the Pi0.5 convention in the steerable - annotation plan (Section I.7), implemented inside the processor / - forward path. Recipes drive it via ``stream`` + ``target`` metadata. - - Compared to ``SmolVLAConfig`` this adds: - - - ``recipe_path``: path to a ``TrainingRecipe`` YAML (loaded by the - train script). 
When ``None``, SmolVLA2 falls back to the SmolVLA - task-only path so unannotated datasets still work. - - ``text_loss_weight`` / ``flow_loss_weight``: relative weights when - both losses are active in a single sample. - - ``unfreeze_lm_head``: must be ``True`` for the text head to learn — - SmolVLA freezes ``lm_head`` to "avoid unused params issues" and we - need to undo that for SmolVLA2. - - ``train_expert_only=False`` by default, since the VLM body now also - participates in text-target gradients. - """ - - # Recipe / language stack --------------------------------------------- - recipe_path: str | None = "recipes/subtasks_vqa.yaml" - """Path (absolute or relative to ``src/lerobot/configs/``) to a - ``TrainingRecipe`` YAML. The default points at the canonical Hi Robot - blend shipped alongside SmolVLA2. Set to ``None`` to disable recipe - rendering and fall back to SmolVLA's single-task prompt path - (unannotated datasets keep working that way).""" - - apply_chat_template: bool = True - """Apply the SmolVLM tokenizer's chat template to the rendered messages - before tokenizing. SmolVLM's backbone is chat-pretrained, so this - matches its training distribution.""" - - # Loss weights -------------------------------------------------------- - # Pi 0.5 paper §IV.D (Eq. 1) sets α = 10 between the text-CE term - # and the flow-MSE term: L = H(text) + α * ‖ω - a - f_θ‖². The - # rationale is that actions are the primary output and the flow - # head should dominate the gradient signal; text is supervised as - # an auxiliary task and its CE scale (~0.5-2.0 in nats) tends to - # be larger than the flow MSE scale (~0.1-1.0), so without - # up-weighting the action head gets starved. We use a milder - # split (5:1) than the paper's α=10: ~40% of the blend is the - # flow-only ``low_level`` recipe, so the flow term already fires - # often, and α=10 starved the text head into degenerate decoding. - text_loss_weight: float = 1.0 - """Weight on the LM-head cross-entropy term. 
Set to ``0`` to disable - text training entirely (reverts to flow-only / SmolVLA behaviour).""" - - flow_loss_weight: float = 5.0 - """Weight on the action-expert flow-matching term. Default 5.0 — a - milder split than the Pi 0.5 paper's α=10 (§IV.D), since the - flow-only ``low_level`` recipe already gives the action expert - frequent gradient. Set lower if the text head is underfitting - relative to the action expert; set higher if the action expert is - degrading because text loss dominates.""" - - # Optimizer ----------------------------------------------------------- - optimizer_lr: float = 2.5e-5 - """Peak learning rate. Overrides ``SmolVLAConfig``'s ``1e-4``. - - SmolVLA can afford ``1e-4`` because it *freezes* the language head — - only the from-scratch action expert sees that LR. SmolVLA2 unfreezes - ``lm_head`` + the last text layer and fine-tunes the **pretrained** - SmolVLM2 language weights, and ``1e-4`` is too aggressive for a - pretrained LM: it destabilises the language representations and - collapses generation into degenerate repetition. ``2.5e-5`` matches - pi05's peak LR (openpi ``CosineDecaySchedule``), the comparable - text-co-trained policy. The action expert trains slightly slower at - this LR, so budget more steps.""" - - # Backbone training --------------------------------------------------- - unfreeze_lm_head: bool = True - """Whether to unfreeze the SmolVLM ``lm_head`` (and the immediately - preceding norm + last text-model layer that SmolVLA freezes). Must be - ``True`` for the text head to learn. Setting this to ``False`` - effectively reduces SmolVLA2 back to SmolVLA's flow-only training, - which is occasionally useful for ablations.""" - - load_vlm_weights: bool = True - """Load the pretrained SmolVLM2 backbone weights (vision tower + - language model + ``lm_head``) instead of random-initialising them. - - ``SmolVLAConfig`` defaults this to ``False`` because the original - SmolVLA pre-training run trained the VLM body itself. 
For SmolVLA2 - that default is a footgun: the text head **is** the SmolVLM2 - ``lm_head``, and the high-level subtask supervision is hopeless if - it starts from a random language model — it can only memorise. - SmolVLA2 therefore defaults this to ``True`` so every run fine-tunes - from the pretrained ``vlm_model_name`` checkpoint - (``HuggingFaceTB/SmolVLM2-500M-Video-Instruct``). - - Note this loads the *VLM backbone* pretrained; the action expert - still trains from scratch on the robot data (standard SmolVLA - fine-tuning). To also start the action expert from pretrained - weights, fine-tune from a full ``lerobot/smolvla_base`` checkpoint - via ``--policy.path``.""" - - # Per-component prompt dropout (Pi0.7 §V.E) --------------------------- - # At training, randomly drop non-target context messages whose - # content was substituted from the named recipe binding. Forces - # the model to handle missing context — directly attacks the - # memorisation collapse where a stale or missing plan/memory at - # inference puts the prompt out-of-distribution and the LM head - # falls back to dominant-mode fragments. All default to 0.0 so - # behaviour is identical until explicitly enabled. - plan_dropout_prob: float = 0.0 - """Drop messages whose content starts with ``Plan:`` or ``Previous plan`` - with this probability per sample.""" - memory_dropout_prob: float = 0.0 - """Drop messages whose content starts with ``Memory:`` or ``Previous memory`` - with this probability per sample.""" - subtask_dropout_prob: float = 0.0 - """Drop messages whose content starts with ``Current subtask`` or - ``Completed subtask`` with this probability per sample.""" - - def __post_init__(self) -> None: - super().__post_init__() - # Backbone needs gradients flowing through its text path when the - # LM head is producing supervised text. Override the SmolVLA - # default (`train_expert_only=True`) unless the user explicitly - # opts out of text training via `text_loss_weight=0`. 
- if self.text_loss_weight > 0 and self.unfreeze_lm_head: - # The user can still flip this back via CLI; this only - # changes the *default* when SmolVLA2 is actually training a - # text head. - self.train_expert_only = False diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py deleted file mode 100644 index 08660aa00..000000000 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ /dev/null @@ -1,692 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy. - -Adds: - -* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised, -* a forward path that runs the flow head, the text head, or both, - driven by ``batch["predict_actions"]`` and ``batch["text_labels"]`` - produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on - this branch). - -Per-sample routing — within one batch: - -* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow - loss (action chunk supervision). -* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the - flow loss; only its text tokens (where ``text_labels[i, t] != -100``) - contribute to the LM-head cross-entropy. - -Falls back to ``SmolVLAPolicy.forward`` cleanly when neither -``text_labels`` nor ``predict_actions`` is in the batch — unannotated -datasets keep working unchanged. 
-""" - -from __future__ import annotations - -import math -from typing import Any - -import torch -import torch.nn.functional as F -from torch import Tensor - -from lerobot.utils.constants import ( - ACTION, - OBS_LANGUAGE_ATTENTION_MASK, - OBS_LANGUAGE_TOKENS, - OBS_STATE, -) - -from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks -from .configuration_smolvla2 import SmolVLA2Config - - -def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, int]: - """Find ``[lang_start, lang_end)`` inside the SmolVLA prefix. - - ``embed_prefix`` lays out the prefix as - ``[image_blocks..., lang, state, padding]`` with the att-mask - convention ``[0]*image, [0]*lang, [1]*state, [0]*padding`` (see - ``modeling_smolvla.SmolVLAModel.embed_prefix``). State is exactly - one token, and it's the *only* position with ``att_mask == 1``, - so we use the first ``1`` to anchor lang_end. Computing it this - way is robust to (a) state being projected to one embedding token - regardless of its raw feature dim, and (b) the trailing padding - added when ``seq_len < prefix_length``. - """ - row = prefix_att_masks[0] - ones = row.nonzero(as_tuple=False) - if ones.numel() == 0: - raise RuntimeError( - "SmolVLA2: state token not found in prefix att_masks — " - "can't locate language range." - ) - state_start = int(ones[0, 0].item()) - lang_end = state_start - lang_start = lang_end - num_lang - if lang_start < 0: - raise RuntimeError( - f"SmolVLA2: lang range underflows prefix " - f"(state_start={state_start}, num_lang={num_lang})." - ) - return lang_start, lang_end - - -def _mark_target_span_causal( - prefix_att_masks: Tensor, text_labels: Tensor, lang_start: int, lang_end: int -) -> Tensor: - """Make the supervised text-target span causally masked. 
- - ``embed_prefix`` flags every language token with ``att=0``, which - ``make_att_2d_masks`` turns into one fully *bidirectional* block — - so a target token's hidden state attends to the very tokens it is - supposed to predict. The text cross-entropy then degenerates into - a copy task (loss → ~0) and the model never learns causal - next-token prediction — at inference, where ``select_message`` - decodes autoregressively (causally), it collapses. - - Fix: set ``att=1`` on the language positions that are supervised - targets (``text_labels != -100``). With ``make_att_2d_masks``'s - cumulative-block rule each target token then attends to images + - the user prompt bidirectionally and to *earlier* target tokens - only — i.e. genuine causal next-token prediction, matching - inference. Non-target language (the user prompt, and the - ``low_level_execution`` subtask which is a user turn, not a - target) stays ``att=0`` / bidirectional. The action expert is - unaffected: the suffix has a strictly higher cumsum so it still - attends to every prefix token. 
- """ - att = prefix_att_masks.clone() - n = min(text_labels.shape[1], lang_end - lang_start) - if n <= 0: - return att - target = text_labels[:, :n] != -100 # (B, n) bool - seg = att[:, lang_start : lang_start + n].bool() - att[:, lang_start : lang_start + n] = seg | target - return att - - -def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor: - """Next-token CE: hidden at t predicts label at t+1, ignore_index=-100.""" - num_lang = logits.shape[1] - if text_labels.shape[1] != num_lang: - common = min(text_labels.shape[1], num_lang) - logits = logits[:, :common] - text_labels = text_labels[:, :common] - shift_logits = logits[:, :-1, :].contiguous() - shift_labels = text_labels[:, 1:].contiguous().long() - valid = shift_labels != -100 - loss = F.cross_entropy( - shift_logits.reshape(-1, shift_logits.shape[-1]), - shift_labels.reshape(-1), - ignore_index=-100, - reduction="sum", - ) - return loss / valid.sum().clamp(min=1) - - -class SmolVLA2Policy(SmolVLAPolicy): - """SmolVLA + re-enabled SmolVLM language head.""" - - config_class = SmolVLA2Config - name = "smolvla2" - - def __init__(self, config: SmolVLA2Config, **kwargs): - if not isinstance(config, SmolVLA2Config): - config = SmolVLA2Config( - **{ - f.name: getattr(config, f.name) - for f in config.__dataclass_fields__.values() - if hasattr(config, f.name) - } - ) - super().__init__(config, **kwargs) - if config.unfreeze_lm_head and config.text_loss_weight > 0: - self._unfreeze_lm_head() - - # ------------------------------------------------------------------ - # Backbone surgery - # ------------------------------------------------------------------ - - def _unfreeze_lm_head(self) -> None: - """Re-enable gradients on the text-output path so the LM head - loss can flow back. - - SmolVLA's ``set_requires_grad`` freezes three things when - ``train_expert_only=False``: ``lm_head``, - ``text_model.model.norm.weight``, and the last 1-2 text-model - transformer layers (see ``smolvlm_with_expert.py:167-176``). 
- We must unfreeze *all three* — otherwise gradients still die - in the frozen final block and the lm_head learns nothing. - """ - vlm_with_expert = getattr(self.model, "vlm_with_expert", None) - if vlm_with_expert is None: - return - vlm = getattr(vlm_with_expert, "vlm", None) - if vlm is None: - return - - # Mirror the freeze targets from ``smolvlm_with_expert.set_requires_grad``. - num_vlm = getattr(vlm_with_expert, "num_vlm_layers", None) - num_expert = getattr(vlm_with_expert, "num_expert_layers", None) - last_layers = [] - if num_vlm is not None: - last_layers.append(num_vlm - 1) - if ( - num_expert is not None - and num_vlm != num_expert - and num_vlm % num_expert == 0 - ): - last_layers.append(num_vlm - 2) - unfreeze_prefixes = [ - "lm_head", - "text_model.model.norm.weight", - *[f"text_model.model.layers.{layer}." for layer in last_layers], - ] - for name, param in vlm.named_parameters(): - if any(k in name for k in unfreeze_prefixes): - param.requires_grad = True - - # ------------------------------------------------------------------ - # Forward - # ------------------------------------------------------------------ - - def forward( - self, - batch: dict[str, Tensor], - noise: Tensor | None = None, - time: Tensor | None = None, - reduction: str = "mean", - ) -> tuple[Tensor, dict[str, Any]]: - """Forward pass with optional dual-head loss. - - Two routing knobs from the batch (produced by - :class:`SmolVLA2ChatTokenizerStep`): - - * ``text_labels`` — per-token labels with ``-100`` for non-target - positions. Triggers the text-loss path through ``lm_head``. - * ``predict_actions`` — per-sample bool tensor. ``True`` ⇒ - include this sample's action chunk in the flow loss. - - When neither is present, delegate to ``SmolVLAPolicy.forward``. 
- """ - text_labels = batch.get("text_labels") - predict_actions_t = batch.get("predict_actions") - - has_text_data = ( - text_labels is not None - and isinstance(text_labels, Tensor) - and self.config.text_loss_weight > 0 - ) - has_per_sample_routing = ( - predict_actions_t is not None and isinstance(predict_actions_t, Tensor) - ) - - if not has_text_data and not has_per_sample_routing: - return super().forward(batch, noise=noise, time=time, reduction=reduction) - - loss_dict: dict[str, Any] = {} - device = batch[OBS_STATE].device - total = torch.zeros((), device=device, dtype=torch.float32) - - run_flow = ( - self.config.flow_loss_weight > 0 - and ACTION in batch - and (not has_per_sample_routing or bool(predict_actions_t.any().item())) - ) - - # ------------------------------------------------------------ - # Fused path — one backbone forward for flow + text together. - # ------------------------------------------------------------ - if run_flow and has_text_data: - flow_loss, text_loss, flow_diag = self._compute_fused_loss( - batch, text_labels, predict_actions_t, noise=noise, time=time - ) - total = total + self.config.flow_loss_weight * flow_loss - total = total + self.config.text_loss_weight * text_loss - loss_dict["flow_loss"] = float(flow_loss.detach().item()) - loss_dict["text_loss"] = float(text_loss.detach().item()) - for k, v in flow_diag.items(): - loss_dict[f"flow_{k}"] = v - elif run_flow: - per_sample_flow, flow_diag = super().forward( - batch, noise=noise, time=time, reduction="none" - ) - if has_per_sample_routing: - mask = predict_actions_t.to(per_sample_flow.dtype) - masked = per_sample_flow * mask - denom = mask.sum().clamp(min=1.0) - flow_loss = masked.sum() / denom - else: - flow_loss = per_sample_flow.mean() - total = total + self.config.flow_loss_weight * flow_loss - loss_dict["flow_loss"] = float(flow_loss.detach().item()) - for k, v in flow_diag.items(): - loss_dict[f"flow_{k}"] = v - elif has_text_data: - text_loss = 
self._compute_text_loss(batch, text_labels) - total = total + self.config.text_loss_weight * text_loss - loss_dict["text_loss"] = float(text_loss.detach().item()) - else: - # No path fired — happens when both loss weights are 0 or - # the batch has neither action samples nor supervised text. - # Fail loud rather than train silently on a zero loss. - raise RuntimeError( - "SmolVLA2Policy.forward: nothing to train — " - "flow_loss_weight=%s, text_loss_weight=%s, " - "predict_actions.any()=%s, has_text_data=%s" - % ( - self.config.flow_loss_weight, - self.config.text_loss_weight, - bool(predict_actions_t.any().item()) if has_per_sample_routing else None, - has_text_data, - ) - ) - - loss_dict["loss"] = float(total.detach().item()) - - if reduction == "none": - # Per-sample loss isn't meaningfully defined for the dual - # path; broadcast the scalar to (B,) for caller compat. - return total.expand(batch[OBS_STATE].shape[0]), loss_dict - return total, loss_dict - - # ------------------------------------------------------------------ - # Text-loss internals - # ------------------------------------------------------------------ - - def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor: - """Cross-entropy on the SmolVLM ``lm_head`` over target tokens.""" - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens = batch[OBS_LANGUAGE_TOKENS] - lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state - ) - # Causally mask the supervised target span so the text-CE is - # genuine next-token prediction (see ``_mark_target_span_causal``). 
- lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1]) - prefix_att_masks = _mark_target_span_causal( - prefix_att_masks, text_labels, lang_start, lang_end - ) - prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) - prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - - # Prefix-only forward. - out_pair, _ = self.model.vlm_with_expert.forward( - attention_mask=prefix_att_2d_masks, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, None], - use_cache=False, - fill_kv_cache=True, - ) - prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair - if prefix_out is None: - raise RuntimeError( - "SmolVLA2: vlm_with_expert.forward returned no prefix hidden " - "states — text-loss path needs them." - ) - - # ``lang_start`` / ``lang_end`` were located above on the - # *unmodified* att masks — don't recompute here, because - # ``_mark_target_span_causal`` set target lang tokens to 1 and - # ``_locate_lang_range`` keys on the first 1 (the state token). - vlm = self.model.vlm_with_expert.vlm - lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype) - logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab) - - return _shifted_ce(logits, text_labels) - - # ------------------------------------------------------------------ - # Fused flow + text loss (single backbone forward) - # ------------------------------------------------------------------ - - def _compute_fused_loss( - self, - batch: dict[str, Tensor], - text_labels: Tensor, - predict_actions_t: Tensor | None, - noise: Tensor | None = None, - time: Tensor | None = None, - ) -> tuple[Tensor, Tensor, dict[str, Any]]: - """One backbone forward → both flow MSE and text CE. 
- - Mirrors ``SmolVLAModel.forward`` (prefix + suffix concat, one - ``vlm_with_expert`` call) but captures **both** outputs: - - * ``prefix_out[:, lang_start:lang_end]`` → ``lm_head`` → CE on - ``text_labels`` (same slicing as ``_compute_text_loss``). - * ``suffix_out[:, -chunk_size:]`` → ``action_out_proj`` → flow - MSE against ``noise - actions`` (same as the parent forward). - - Saves one backbone pass per training step vs. running the flow - and text paths separately — same trick PI052Policy uses in - ``_compute_all_losses_fused``. - """ - cfg = self.config - if cfg.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens = batch[OBS_LANGUAGE_TOKENS] - lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] - actions = self.prepare_action(batch) - - inner = self.model - if noise is None: - noise = inner.sample_noise(actions.shape, actions.device) - if time is None: - time = inner.sample_time(actions.shape[0], actions.device) - - time_expanded = time[:, None, None] - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions - - prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state - ) - # Causally mask the supervised text-target span (see - # ``_mark_target_span_causal``). Per-sample: high_level_subtask - # samples have a subtask target → causal; low_level_execution - # samples have all -100 labels → untouched / bidirectional, so - # the action expert still reads the subtask as bidirectional - # context. ``lang_start`` / ``lang_end`` located here on the - # unmodified mask and reused for the text-loss slice below. 
- lang_start, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1]) - prefix_att_masks = _mark_target_span_causal( - prefix_att_masks, text_labels, lang_start, lang_end - ) - suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time) - - pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) - att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - att_2d_masks = make_att_2d_masks(pad_masks, att_masks) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - - out_pair, _ = inner.vlm_with_expert.forward( - attention_mask=att_2d_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - fill_kv_cache=False, - ) - prefix_out, suffix_out = out_pair[0], out_pair[1] - if prefix_out is None or suffix_out is None: - raise RuntimeError( - "SmolVLA2: fused forward expected both prefix and suffix " - "hidden states from vlm_with_expert." - ) - - # ---------------- flow loss (per-sample maskable) ---------------- - chunk = cfg.chunk_size - suffix_chunk = suffix_out[:, -chunk:].to(torch.float32) - v_t = inner.action_out_proj(suffix_chunk) - losses = F.mse_loss(u_t, v_t, reduction="none") - - original_action_dim = cfg.action_feature.shape[0] - losses = losses[:, :, :original_action_dim] - flow_diag = {"losses_after_forward": float(losses.detach().mean().item())} - - actions_is_pad = batch.get("action_is_pad") - if actions_is_pad is not None: - in_episode = ~actions_is_pad - losses = losses * in_episode.unsqueeze(-1) - flow_diag["losses_after_in_ep_bound"] = float(losses.detach().mean().item()) - - losses = losses[:, :, : cfg.max_action_dim] - flow_diag["losses_after_rm_padding"] = float(losses.detach().mean().item()) - - per_sample_flow = losses.mean(dim=(1, 2)) - if predict_actions_t is not None: - mask = predict_actions_t.to(per_sample_flow.dtype) - flow_loss = (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0) - else: - flow_loss = 
per_sample_flow.mean() - - # ---------------- text loss (lang slice of prefix) --------------- - # ``lang_start`` / ``lang_end`` from above (unmodified mask). - vlm = inner.vlm_with_expert.vlm - lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype) - logits = vlm.lm_head(lang_hidden) - text_loss = _shifted_ce(logits, text_labels) - - return flow_loss, text_loss, flow_diag - - # ------------------------------------------------------------------ - # Inference: text generation - # ------------------------------------------------------------------ - - @torch.no_grad() - def select_message( - self, - batch: dict[str, Tensor], - *, - max_new_tokens: int = 256, - min_new_tokens: int = 0, - eos_token_id: int | None = None, - temperature: float = 0.0, - top_p: float = 1.0, - tokenizer: Any = None, - ) -> str: - """Generate text continuation from the chat-templated prompt. - - AR decoding with KV caching reused from SmolVLA's inference - path. Batch size is assumed to be 1 (the runtime calls this - per-event). Returns the decoded string of new tokens (the - prompt itself is not included). - - Parameters - ---------- - batch: - Already through the SmolVLA2 preprocessor — expects - ``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``, - ``OBS_LANGUAGE_ATTENTION_MASK``. - max_new_tokens: - Hard cap on generated tokens; stops earlier on EOS. - eos_token_id: - Override the tokenizer's EOS. ``None`` ⇒ use the - tokenizer's default. - temperature, top_p: - ``temperature=0`` does greedy argmax (default — matches - training distribution most closely). Set ``temperature>0`` - with optional ``top_p<1`` for nucleus sampling. - tokenizer: - Optional pre-loaded tokenizer to avoid the cold-start - ``AutoTokenizer.from_pretrained`` round-trip on every call. 
- """ - self.eval() - - if tokenizer is None: - from transformers import AutoTokenizer # noqa: PLC0415 - - tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name) - if eos_token_id is None: - eos_token_id = tokenizer.eos_token_id - - # Build the full set of special-token ids to suppress during - # the ``min_new_tokens`` window. EOS alone is not enough on a - # memorised SmolVLM head — when EOS is masked, the argmax - # falls onto a sibling special token (``<|im_end|>``, - # ````, ````, ````, - # …) which then survives generation but gets stripped by - # ``skip_special_tokens=True`` so ``decode`` returns an empty - # string and the runtime sees ``last_raw='(empty)'`` every - # chunk boundary. - special_ids_set: set[int] = set() - try: - for sid in (tokenizer.all_special_ids or []): - if sid is not None: - special_ids_set.add(int(sid)) - except Exception: # noqa: BLE001 - pass - if eos_token_id is not None: - special_ids_set.add(int(eos_token_id)) - - # Match training's text-loss forward path (see - # ``_compute_text_loss`` above): build the full prefix via - # ``embed_prefix`` so images + state conditioning is intact, - # then loop AR with ``fill_kv_cache=True, use_cache=False``. - # That flag combo routes every layer through - # ``forward_attn_layer`` (which gracefully skips ``None`` - # expert inputs via ``if hidden_states is None or layer is - # None: continue``) and short-circuits the cache-update logic - # so we don't have to manage past_kv. Each step just - # re-forwards the cumulative ``[prefix + generated]`` - # sequence. - # - # This is O(n²) in generated text length but cheap in - # absolute terms: image encoding happens once via the initial - # ``embed_prefix`` call, and the per-step cost is just one - # SmolVLM transformer pass over a sequence that grows by one - # token each time. 
Standard SmolVLM ``generate`` was the - # other tempting path, but it can't accept SmolVLA's custom - # ``state_proj`` output and its tile-grid expectations - # disagree with our preprocessor — both lead to garbage - # generation, which is what the prior approach produced. - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens = batch[OBS_LANGUAGE_TOKENS] - lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state - ) - - # ``embed_prefix`` lays the prefix out as ``[images, lang, state]`` - # — the state token is LAST. Training supervises the text head on - # the *language* positions (see ``_compute_text_loss`` / - # ``_compute_fused_loss``: lm_head over ``prefix_out[lang_start: - # lang_end]``). So AR text generation must continue from the last - # language token (right after the ``Assistant:`` generation - # prompt) — NOT from the state token, whose hidden state exists - # for the action expert to read and which the lm_head was never - # trained to decode subtask text from. Truncating the state token - # here makes ``prefix_out[:, -1:]`` in the loop below the last - # language position, matching the training distribution. 
- _, lang_end = _locate_lang_range(prefix_att_masks, lang_tokens.shape[1]) - prefix_embs = prefix_embs[:, :lang_end] - prefix_pad_masks = prefix_pad_masks[:, :lang_end] - prefix_att_masks = prefix_att_masks[:, :lang_end] - - device = prefix_embs.device - bsize = prefix_embs.shape[0] - vlm = self.model.vlm_with_expert.vlm - emb_dim = prefix_embs.shape[-1] - text_emb_scale = math.sqrt(emb_dim) - - current_embs = prefix_embs - current_pad = prefix_pad_masks - current_att = prefix_att_masks - ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device) - - generated: list[int] = [] - for _ in range(max_new_tokens): - full_2d = make_att_2d_masks(current_pad, current_att) - full_pos = torch.cumsum(current_pad, dim=1) - 1 - - out_pair, _ = self.model.vlm_with_expert.forward( - attention_mask=full_2d, - position_ids=full_pos, - past_key_values=None, - inputs_embeds=[current_embs, None], - use_cache=False, - fill_kv_cache=True, - ) - prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair - if prefix_out is None: - raise RuntimeError( - "select_message: vlm_with_expert.forward returned no hidden states." - ) - - last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype) - logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V) - # Suppress *all* special tokens until we've decoded - # ``min_new_tokens`` real (renderable) tokens. Without - # this, a memorised SmolVLM head whose argmax at position - # 0 is a special token produces an empty completion every - # time — either EOS directly, or (after we mask EOS) the - # argmax shifts to a sibling special id (``<|im_end|>``, - # ````, ````, …) which decode strips - # via ``skip_special_tokens=True``. Masking the full - # ``all_special_ids`` set for the first N steps forces - # the head to commit to a normal vocabulary token before - # it can close (or quietly poison) the turn. 
- if special_ids_set and len(generated) < min_new_tokens: - for sid in special_ids_set: - logits_step[..., sid] = float("-inf") - next_ids = self._sample_next_token(logits_step, temperature, top_p) - tok_id = int(next_ids[0].item()) - generated.append(tok_id) - if eos_token_id is not None and tok_id == eos_token_id: - break - - new_emb = self.model.vlm_with_expert.embed_language_tokens( - next_ids.unsqueeze(0) - ) - new_emb = new_emb * text_emb_scale - current_embs = torch.cat([current_embs, new_emb], dim=1) - current_pad = torch.cat([current_pad, ones_step], dim=1) - current_att = torch.cat([current_att, ones_step], dim=1) - - decoded = tokenizer.decode(generated, skip_special_tokens=True).strip() - # When the visible decoded string is empty but tokens *were* - # generated, expose what those raw tokens decoded to without - # the special-token filter. This is what the runtime turns - # into a scrollback line when ``last_raw='(empty)'`` so the - # operator can tell whether the head is emitting EOS, image - # placeholder tokens, the chat-template ``<|im_end|>`` shard, - # or something else. - if not decoded and generated: - try: - self._last_select_message_debug = ( - f"raw_ids={generated[:16]} " - f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}" - ) - except Exception: # noqa: BLE001 - self._last_select_message_debug = f"raw_ids={generated[:16]}" - else: - self._last_select_message_debug = "" - return decoded - - @staticmethod - def _sample_next_token( - logits: Tensor, temperature: float, top_p: float - ) -> Tensor: - """Pick one token id per batch row from ``logits``.""" - if temperature <= 0.0: - return logits.argmax(dim=-1) - scaled = logits / max(temperature, 1e-6) - probs = F.softmax(scaled, dim=-1) - if top_p < 1.0: - sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True) - cum = sorted_probs.cumsum(dim=-1) - mask = cum > top_p - # Always keep the most-likely token. 
- mask[..., 0] = False - sorted_probs = sorted_probs.masked_fill(mask, 0.0) - sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9) - pick = torch.multinomial(sorted_probs, num_samples=1) - return sorted_idx.gather(-1, pick).squeeze(-1) - return torch.multinomial(probs, num_samples=1).squeeze(-1) diff --git a/src/lerobot/policies/smolvla2/processor_smolvla2.py b/src/lerobot/policies/smolvla2/processor_smolvla2.py deleted file mode 100644 index 9d0913e0b..000000000 --- a/src/lerobot/policies/smolvla2/processor_smolvla2.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""SmolVLA2 processor pipelines. - -When ``config.recipe_path`` is set, the pre-processor pipeline becomes: - - rename observations - add batch dim - RenderMessagesStep(recipe) # PR 1: language_* → messages - SmolVLA2ChatTokenizerStep(...) # chat template + label mask + predict_actions - DeviceProcessorStep - NormalizerProcessorStep - -When ``config.recipe_path`` is ``None``, we delegate to SmolVLA's -plain task-string pipeline so unannotated datasets still work. - -Post-processor is unchanged from SmolVLA. 
-""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any - -import torch - -from lerobot.configs.recipe import TrainingRecipe -from lerobot.processor import ( - AddBatchDimensionProcessorStep, - DeviceProcessorStep, - NormalizerProcessorStep, - PolicyAction, - PolicyProcessorPipeline, - RenameObservationsProcessorStep, - RenderMessagesStep, - UnnormalizerProcessorStep, - policy_action_to_transition, - transition_to_policy_action, -) -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME - -from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors -from .chat_processor_smolvla2 import SmolVLA2ChatTokenizerStep -from .configuration_smolvla2 import SmolVLA2Config - - -def make_smolvla2_pre_post_processors( - config: SmolVLA2Config, - dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, -) -> tuple[ - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], - PolicyProcessorPipeline[PolicyAction, PolicyAction], -]: - """Build SmolVLA2's pre/post-processor pipelines. - - With ``recipe_path`` set, inserts the recipe-rendering step and the - chat-template tokenizer that emits ``text_labels`` and - ``predict_actions`` for the dual-loss path. Without it, falls back - to SmolVLA's plain task-string pipeline so unannotated datasets - keep working unchanged. 
- """ - if not config.recipe_path: - return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats) - - recipe = _load_recipe(config.recipe_path) - - input_steps = [ - RenameObservationsProcessorStep(rename_map={}), - AddBatchDimensionProcessorStep(), - RenderMessagesStep(recipe=recipe), - SmolVLA2ChatTokenizerStep( - tokenizer_name=config.vlm_model_name, - max_length=config.tokenizer_max_length, - padding=config.pad_language_to, - plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0), - memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0), - subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0), - ), - DeviceProcessorStep(device=config.device), - NormalizerProcessorStep( - features={**config.input_features, **config.output_features}, - norm_map=config.normalization_mapping, - stats=dataset_stats, - ), - ] - output_steps = [ - UnnormalizerProcessorStep( - features=config.output_features, - norm_map=config.normalization_mapping, - stats=dataset_stats, - ), - DeviceProcessorStep(device="cpu"), - ] - return ( - PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( - steps=input_steps, - name=POLICY_PREPROCESSOR_DEFAULT_NAME, - ), - PolicyProcessorPipeline[PolicyAction, PolicyAction]( - steps=output_steps, - name=POLICY_POSTPROCESSOR_DEFAULT_NAME, - to_transition=policy_action_to_transition, - to_output=transition_to_policy_action, - ), - ) - - -def _load_recipe(path_str: str) -> TrainingRecipe: - """Resolve ``path_str`` to a ``TrainingRecipe``. - - Accepts an absolute path or a path relative to - ``src/lerobot/configs/`` so recipe authors can write - ``--policy.recipe_path=recipes/subtasks_vqa.yaml``. 
- """ - p = Path(path_str) - if not p.is_absolute() and not p.exists(): - from lerobot.configs import recipe as _recipe_module # noqa: PLC0415 - - configs_dir = Path(_recipe_module.__file__).resolve().parent - candidate = configs_dir / path_str - if candidate.exists(): - p = candidate - return TrainingRecipe.from_yaml(p) diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_pi052_runtime.py similarity index 95% rename from src/lerobot/scripts/lerobot_smolvla2_runtime.py rename to src/lerobot/scripts/lerobot_pi052_runtime.py index 04f327ec0..032136e03 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_pi052_runtime.py @@ -12,10 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2. +"""``lerobot-pi052-runtime`` — interactive REPL for trained PI052. Drives the multi-rate runtime defined in -:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user +:mod:`lerobot.policies.pi052.inference`. Stdin becomes the user channel: type a task, then natural-language interjections / questions. The runtime prints state changes (plan / subtask / memory / vqa / speech) as they happen. 
@@ -26,16 +26,16 @@ Examples Dry run on a Hub checkpoint, no robot connected — useful for sanity- checking text generation:: - uv run lerobot-smolvla2-runtime \\ - --policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\ + uv run lerobot-pi052-runtime \\ + --policy.path=pepijn223/pi052_hirobot_super_poulain_tool2 \\ --no_robot \\ --task="please clean the kitchen" Same, but feed real frames from an annotated dataset so plan / subtask / memory / VQA generation runs against actual video + state:: - uv run lerobot-smolvla2-runtime \\ - --policy.path=pepijn223/smolvla2_hirobot_super_poulain_tool2 \\ + uv run lerobot-pi052-runtime \\ + --policy.path=pepijn223/pi052_hirobot_super_poulain_tool2 \\ --dataset.repo_id=pepijn223/super_poulain_annotated \\ --dataset.episode=0 \\ --no_robot \\ @@ -43,7 +43,7 @@ Same, but feed real frames from an annotated dataset so plan / subtask With a real robot:: - uv run lerobot-smolvla2-runtime \\ + uv run lerobot-pi052-runtime \\ --policy.path=... \\ --robot.type=so101 --robot.port=/dev/tty.usbmodem... \\ --tts.voice=alba @@ -62,18 +62,14 @@ import logging import sys from typing import Any, Callable -logger = logging.getLogger("lerobot.smolvla2.runtime") +logger = logging.getLogger("lerobot.pi052.runtime") def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: p = argparse.ArgumentParser( - # prog defaults to the invoked command name, so this reads - # correctly whether run as lerobot-smolvla2-runtime or - # lerobot-pi052-runtime. description=( - "Interactive REPL runtime for a trained hierarchical VLA " - "checkpoint (SmolVLA2 or PI052 — policy type is read from " - "the checkpoint)." + "Interactive REPL runtime for a trained PI052 hierarchical " + "VLA checkpoint." ), ) p.add_argument( @@ -82,7 +78,7 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: type=str, required=True, help=( - "Local directory or Hugging Face Hub repo id pointing at a trained SmolVLA2 ``pretrained_model``." 
+ "Local directory or Hugging Face Hub repo id pointing at a trained PI052 ``pretrained_model``." ), ) p.add_argument( @@ -368,7 +364,7 @@ def _load_policy_and_preprocessor( policy_path: str, dataset_repo_id: str | None, ) -> tuple[Any, Any, Any, Any]: - """Load a SmolVLA2 checkpoint (local path or Hub repo id). + """Load a PI052 checkpoint (local path or Hub repo id). Returns ``(policy, preprocessor, postprocessor, ds_meta)``. ``preprocessor`` / ``postprocessor`` / ``ds_meta`` are ``None`` @@ -430,7 +426,7 @@ def _build_observation_provider( The dataset's ``language_persistent`` / ``language_events`` columns are stripped before the sample reaches the preprocessor, - so ``RenderMessagesStep`` and ``SmolVLA2ChatTokenizerStep`` are + so ``RenderMessagesStep`` and ``PI052TextTokenizerStep`` are no-ops; the runtime supplies its own messages from current state. """ import torch # noqa: PLC0415 @@ -613,7 +609,7 @@ def _select_task_interactively( # bootstrap default (may be None — REPL handles that). return bootstrap_task - print("\n[smolvla2] Select startup task:", flush=True) + print("\n[pi052] Select startup task:", flush=True) if options: for i, opt in enumerate(options, 1): marker = " (dataset default)" if opt == bootstrap_task else "" @@ -816,7 +812,7 @@ def _build_robot_observation_provider( # ``ds_features``. The training distribution sees frames at the # recorded resolution (e.g. 480×640); a live Mac/USB camera will # almost always hand us a different native size (720p / 1080p). 
- # SmolVLA's internal ``resize_with_pad(512, 512)`` does pad the + # PI052's internal ``resize_with_pad(512, 512)`` does pad the # input to a fixed canvas, but the *geometry* of that pad differs # by input aspect ratio — top/left padding varies, so the visual # tokens at each tile carry different content than what the model @@ -1017,7 +1013,7 @@ def _build_robot_action_executor( def _print_runtime_help() -> None: """Print the slash-command reference.""" print( - "[smolvla2] commands (arguments need no quotes):\n" + "[pi052] commands (arguments need no quotes):\n" " /action run the robot; an argument switches to that task\n" " /action resume the robot on the current task\n" " /action run the robot for N seconds, then auto-pause\n" @@ -1085,7 +1081,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool: secs = float(rest) runtime.state["action_deadline"] = _time.monotonic() + secs print( - f"[smolvla2] action — running {secs:g}s, then auto-pause", + f"[pi052] action — running {secs:g}s, then auto-pause", flush=True, ) else: @@ -1095,16 +1091,16 @@ def _handle_slash_command(runtime: Any, line: str) -> bool: # New task → drop the stale subtask so the high-level # loop regenerates one for the new goal. 
runtime.state["current_subtask"] = None - print(f"[smolvla2] action — task: {rest!r}", flush=True) + print(f"[pi052] action — task: {rest!r}", flush=True) elif runtime.state.get("task"): print( - f"[smolvla2] action — resuming: {runtime.state['task']!r}", + f"[pi052] action — resuming: {runtime.state['task']!r}", flush=True, ) else: runtime.state["mode"] = "paused" print( - "[smolvla2] no task set — use /action ", + "[pi052] no task set — use /action ", flush=True, ) return True @@ -1113,7 +1109,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool: runtime.state["mode"] = "paused" runtime.state["action_deadline"] = None _clear_action_queue(runtime) - print("[smolvla2] paused — robot holding position", flush=True) + print("[pi052] paused — robot holding position", flush=True) return True if cmd in {"/question", "/q", "/ask", "/vqa", "/vlm"}: @@ -1124,7 +1120,7 @@ def _handle_slash_command(runtime: Any, line: str) -> bool: _clear_action_queue(runtime) if not rest: print( - "[smolvla2] usage: /question " + "[pi052] usage: /question " "(e.g. /question point to the yellow cube)", flush=True, ) @@ -1144,7 +1140,7 @@ def _run_vqa_query(runtime: Any, question: str) -> None: Invoked by ``/question`` — the action loop is paused first so the policy is free for a synchronous VQA call. """ - from lerobot.policies.smolvla2.inference.vqa import handle_vqa_query # noqa: PLC0415 + from lerobot.policies.pi052.inference.vqa import handle_vqa_query # noqa: PLC0415 handle_vqa_query( policy=runtime.policy, @@ -1180,11 +1176,11 @@ def _run_autonomous( if not auto_start and runtime.state.get("mode", "paused") == "action": try: input( - "[smolvla2] Robot connected — starting in ACTION mode. " + "[pi052] Robot connected — starting in ACTION mode. " "Press ENTER to begin, Ctrl+C to abort. 
" ) except (EOFError, KeyboardInterrupt): - print("\n[smolvla2] aborted before start", flush=True) + print("\n[pi052] aborted before start", flush=True) return 130 if initial_task: @@ -1193,7 +1189,7 @@ def _run_autonomous( thread = threading.Thread( target=runtime.run, kwargs={"max_ticks": max_ticks}, - name="smolvla2-runtime-loop", + name="pi052-runtime-loop", daemon=True, ) thread.start() @@ -1251,7 +1247,7 @@ def _run_autonomous( if hasattr(queue, "clear"): queue.clear() print( - "\n[smolvla2] timed action elapsed — paused", + "\n[pi052] timed action elapsed — paused", flush=True, ) else: @@ -1264,7 +1260,7 @@ def _run_autonomous( pass _panel_stop.wait(0.7) - panel_thread = threading.Thread(target=_panel_loop, name="smolvla2-panel-redraw", daemon=True) + panel_thread = threading.Thread(target=_panel_loop, name="pi052-panel-redraw", daemon=True) panel_thread.start() try: @@ -1296,11 +1292,11 @@ def _run_autonomous( runtime.state.setdefault("events_this_tick", []).append("user_interjection") else: print( - "[smolvla2] no task yet — use /action to start", + "[pi052] no task yet — use /action to start", flush=True, ) except KeyboardInterrupt: - print("\n[smolvla2] interrupt — stopping", flush=True) + print("\n[pi052] interrupt — stopping", flush=True) finally: _panel_stop.set() runtime.stop() @@ -1311,9 +1307,9 @@ def _run_autonomous( time.sleep(0.1) try: robot.disconnect() - print("[smolvla2] robot disconnected", flush=True) + print("[pi052] robot disconnected", flush=True) except Exception as exc: # noqa: BLE001 - print(f"[smolvla2] WARNING: robot.disconnect raised {exc}", flush=True) + print(f"[pi052] WARNING: robot.disconnect raised {exc}", flush=True) return 0 @@ -1340,7 +1336,7 @@ def _make_state_panel_renderer( st = runtime.state run_mode = st.get("mode", "action") mode_tag = "[green]mode: action[/]" if run_mode == "action" else "[yellow]mode: paused[/]" - console.rule(f"[bold]SmolVLA2[/] · {mode_label} · {mode_tag}", style="cyan") + 
console.rule(f"[bold]PI052[/] · {mode_label} · {mode_tag}", style="cyan") # Always-visible command hint so the operator never has to # remember the slash commands. if run_mode == "action": @@ -1499,14 +1495,14 @@ def main(argv: list[str] | None = None) -> int: autonomous_mode = bool(args.robot_type) and not args.no_robot if autonomous_mode and not args.dataset_repo_id: print( - "[smolvla2] ERROR: autonomous robot mode requires --dataset.repo_id " + "[pi052] ERROR: autonomous robot mode requires --dataset.repo_id " "for action-denormalisation stats and feature shapes. Pass the " "same dataset the policy was trained on.", file=sys.stderr, ) return 2 - print(f"[smolvla2] loading policy from {args.policy_path}", flush=True) + print(f"[pi052] loading policy from {args.policy_path}", flush=True) policy, preprocessor, postprocessor, ds_meta = _load_policy_and_preprocessor( args.policy_path, args.dataset_repo_id ) @@ -1537,7 +1533,7 @@ def main(argv: list[str] | None = None) -> int: ) if chosen: args.task = chosen - print(f"[smolvla2] task: {args.task!r}", flush=True) + print(f"[pi052] task: {args.task!r}", flush=True) # No startup prompts — the runtime is command-driven. 
It comes up at # the command line in ``paused`` mode (robot idle) unless ``--mode`` @@ -1551,7 +1547,7 @@ def main(argv: list[str] | None = None) -> int: if autonomous_mode: print( - f"[smolvla2] connecting to robot.type={args.robot_type} port={args.robot_port}", + f"[pi052] connecting to robot.type={args.robot_type} port={args.robot_port}", flush=True, ) robot = _build_robot( @@ -1575,7 +1571,7 @@ def main(argv: list[str] | None = None) -> int: ) elif args.dataset_repo_id is not None: print( - f"[smolvla2] streaming observations from {args.dataset_repo_id} " + f"[pi052] streaming observations from {args.dataset_repo_id} " f"episode={args.dataset_episode} " f"start_frame={args.dataset_start_frame}", flush=True, @@ -1592,11 +1588,11 @@ def main(argv: list[str] | None = None) -> int: tools = _build_tools(args.no_tts, args.tts_voice) if tools: - print(f"[smolvla2] tools loaded: {list(tools)}", flush=True) + print(f"[pi052] tools loaded: {list(tools)}", flush=True) - from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415 + from lerobot.policies.pi052.inference import PI052Runtime # noqa: PLC0415 - runtime = SmolVLA2Runtime( + runtime = PI052Runtime( policy=policy, tools=tools, observation_provider=observation_provider, @@ -1662,7 +1658,7 @@ def main(argv: list[str] | None = None) -> int: logger.warning("startup tick failed: %s", exc) startup_logs = [] for line in startup_logs or []: - print(f"[smolvla2] {line}", flush=True) + print(f"[pi052] {line}", flush=True) return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks) @@ -1680,7 +1676,7 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) from rich.console import Console # noqa: PLC0415 except ImportError: print( - "[smolvla2] rich is required for the interactive REPL. `pip install rich` and re-run.", + "[pi052] rich is required for the interactive REPL. 
`pip install rich` and re-run.", file=sys.stderr, ) return 2 @@ -1724,7 +1720,7 @@ def _run_repl(runtime: Any, *, initial_task: str | None, max_ticks: int | None) # task to be meaningful. if not runtime.state.get("task"): print( - "[smolvla2] no task yet — use /action ", + "[pi052] no task yet — use /action ", flush=True, ) _redraw(last_logs) diff --git a/src/lerobot/tools/say.py b/src/lerobot/tools/say.py index a5f2c5f89..6d1079a0b 100644 --- a/src/lerobot/tools/say.py +++ b/src/lerobot/tools/say.py @@ -13,10 +13,9 @@ # limitations under the License. """``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts. -The first concrete tool implementation. SmolVLA2 (PR 3) and downstream -runtime dispatchers consume this when the model emits an assistant -message with ``tool_calls=[{function: {name: "say", arguments: -{text: ...}}}]``. +The first concrete tool implementation. PI052 and downstream runtime +dispatchers consume this when the model emits an assistant message +with ``tool_calls=[{function: {name: "say", arguments: {text: ...}}}]``. 
Why pocket-tts: diff --git a/tests/policies/pi052/test_pi052_vqa_loc.py b/tests/policies/pi052/test_pi052_vqa_loc.py index 9207e4eb4..596778aed 100644 --- a/tests/policies/pi052/test_pi052_vqa_loc.py +++ b/tests/policies/pi052/test_pi052_vqa_loc.py @@ -162,7 +162,7 @@ def test_messages_vqa_to_loc_noop_without_target_indices(): def test_loc_round_trip_keypoint_preserves_normalized_coords(): - from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer + from lerobot.policies.pi052.inference.vqa import parse_vqa_answer answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]} loc = _vqa_answer_to_loc(answer) @@ -175,7 +175,7 @@ def test_loc_round_trip_keypoint_preserves_normalized_coords(): def test_loc_round_trip_bbox_preserves_order_and_scale(): - from lerobot.policies.smolvla2.inference.vqa import parse_vqa_answer + from lerobot.policies.pi052.inference.vqa import parse_vqa_answer answer = { "detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}] diff --git a/tests/policies/smolvla/test_smolvla2_attention_masking.py b/tests/policies/smolvla/test_smolvla2_attention_masking.py deleted file mode 100644 index 7811927ed..000000000 --- a/tests/policies/smolvla/test_smolvla2_attention_masking.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Attention-masking tests for the SmolVLA2 text head. - -Regression coverage for the text-CE collapse bug: ``embed_prefix`` flags -every language token ``att=0``, which ``make_att_2d_masks`` turns into a -single fully *bidirectional* block. Under that mask the text -cross-entropy degenerates into a copy task — a supervised target token -attends to the tokens it is trained to predict — and the model never -learns causal generation, so ``select_message`` collapses at inference. - -``_mark_target_span_causal`` sets ``att=1`` on the supervised target -language positions so each target token attends causally among the -targets while staying bidirectional to images + the user prompt. These -tests pin that behaviour. -""" - -import pytest -import torch - -# The smolvla2 modeling module imports transformers transitively. -pytest.importorskip("transformers") - -from lerobot.policies.smolvla.modeling_smolvla import make_att_2d_masks # noqa: E402 -from lerobot.policies.smolvla2.modeling_smolvla2 import ( # noqa: E402 - _locate_lang_range, - _mark_target_span_causal, -) - -# --------------------------------------------------------------------------- -# A synthetic SmolVLA prefix layout: [images, prompt-lang, target-lang, state] -# -# indices 0-1 : 2 image tokens (att = 0) -# indices 2-4 : 3 user-prompt lang (att = 0) -# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix) -# index 9 : 1 state token (att = 1) -# -# ``text_labels`` covers the 7 language tokens; -100 on the prompt span, -# real ids on the 4-token target span. 
-# --------------------------------------------------------------------------- -N_IMAGE = 2 -N_PROMPT = 3 -N_TARGET = 4 -LANG_START = N_IMAGE -LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = state-token index -PREFIX_LEN = LANG_END + 1 - - -def _embed_prefix_att_masks() -> torch.Tensor: - """Mimic ``embed_prefix``: images + lang all att=0, state att=1.""" - att = torch.zeros(1, PREFIX_LEN, dtype=torch.bool) - att[0, LANG_END] = True # the single state token - return att - - -def _text_labels() -> torch.Tensor: - """-100 over the prompt span, real ids over the target span.""" - labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long) - labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET) - return labels - - -def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor: - """2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j.""" - pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool) - return make_att_2d_masks(pad, prefix_att_masks)[0] - - -def test_locate_lang_range_anchors_on_state_token(): - """``_locate_lang_range`` finds the lang span via the lone att=1 token.""" - lang_start, lang_end = _locate_lang_range( - _embed_prefix_att_masks(), num_lang=N_PROMPT + N_TARGET - ) - assert (lang_start, lang_end) == (LANG_START, LANG_END) - - -def test_mark_sets_att_on_targets_only(): - """Only the supervised target language positions flip to att=1.""" - marked = _mark_target_span_causal( - _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END - ) - expected = [False] * PREFIX_LEN - for i in range(LANG_START + N_PROMPT, LANG_END): # target span - expected[i] = True - expected[LANG_END] = True # state token, untouched - assert marked[0].tolist() == expected - - -def test_target_tokens_attend_causally_among_themselves(): - """A target token must NOT attend to later targets, but must attend - to earlier ones — i.e. 
genuine causal next-token prediction.""" - marked = _mark_target_span_causal( - _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END - ) - attends = _attends(marked) - tgt = range(LANG_START + N_PROMPT, LANG_END) - for i in tgt: - for j in tgt: - if j > i: - assert not attends[i, j], f"target {i} must not see future target {j}" - else: - assert attends[i, j], f"target {i} must see earlier/self target {j}" - - -def test_target_tokens_attend_prompt_and_images_bidirectionally(): - """Targets keep full visibility of images + the user prompt.""" - marked = _mark_target_span_causal( - _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END - ) - attends = _attends(marked) - context = list(range(0, LANG_START + N_PROMPT)) # images + prompt - for i in range(LANG_START + N_PROMPT, LANG_END): - for j in context: - assert attends[i, j], f"target {i} must attend context {j}" - - -def test_action_expert_token_still_sees_full_subtask(): - """The state token (action-expert context) attends to every target — - causal masking the targets must not hide them from the action path.""" - marked = _mark_target_span_causal( - _embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END - ) - attends = _attends(marked) - for j in range(LANG_START + N_PROMPT, LANG_END): - assert attends[LANG_END, j], f"state token must see target {j}" - - -def test_non_target_subtask_stays_bidirectional(): - """``low_level_execution`` renders the subtask as a user turn — its - ``text_labels`` are all -100, so the mask must be left untouched and - the action expert reads the subtask bidirectionally.""" - all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long) - marked = _mark_target_span_causal( - _embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END - ) - assert torch.equal(marked, _embed_prefix_att_masks()) - - -def test_unmarked_mask_is_bidirectional_the_bug(): - """Documents the bug the fix prevents: without ``_mark_target_span_causal`` - a target token 
attends *bidirectionally* to later targets — the - text-CE can copy the answer it is trained to predict.""" - attends = _attends(_embed_prefix_att_masks()) - first_tgt = LANG_START + N_PROMPT - last_tgt = LANG_END - 1 - assert attends[first_tgt, last_tgt], ( - "raw embed_prefix mask is bidirectional over language — the first " - "target token can see the last, which is the collapse bug" - ) diff --git a/tests/policies/smolvla/test_smolvla2_chat_processor.py b/tests/policies/smolvla/test_smolvla2_chat_processor.py deleted file mode 100644 index 26735affe..000000000 --- a/tests/policies/smolvla/test_smolvla2_chat_processor.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for SmolVLA2's chat-tokenizer ``tool_calls`` flattening. - -``_split_plan_and_say`` (inference) expects the model to emit a textual -``...`` marker. ``_flatten_say_tool_calls`` is the training-time -serializer that produces it: it rewrites an assistant turn's structured -``say`` tool call into that marker *inside the content text*, before -``apply_chat_template`` runs — so the chat template only tokenizes plain -text and the supervised target span trains the model to emit the marker -the runtime parses back. These tests pin the round-trip. 
-""" - -from lerobot.policies.smolvla2.chat_processor_smolvla2 import flatten_say_tool_calls -from lerobot.policies.smolvla2.inference.steps import _split_plan_and_say - - -def _say_call(text): - return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}} - - -def test_flatten_appends_say_marker_and_drops_tool_calls(): - msg = {"role": "assistant", "content": "Pick up the blue cube.", "tool_calls": [_say_call("On it!")]} - out = flatten_say_tool_calls(msg) - assert "tool_calls" not in out - assert out["content"] == "Pick up the blue cube.\nOn it!" - - -def test_flatten_roundtrips_through_inference_parser(): - """The marker the serializer writes must be exactly what the inference - parser reads back — this is the train/inference contract.""" - msg = {"role": "assistant", "content": "Move toward the cube.", "tool_calls": [_say_call("Working on it")]} - flat = flatten_say_tool_calls(msg)["content"] - plan, speech = _split_plan_and_say(flat) - assert plan == "Move toward the cube." 
- assert speech == "Working on it" - - -def test_flatten_accepts_json_string_arguments(): - """``arguments`` may arrive as a JSON string rather than a dict.""" - call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}} - out = flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]}) - assert out["content"] == "p\nhello there" - - -def test_flatten_leaves_messages_without_tool_calls_untouched(): - msg = {"role": "assistant", "content": "just a plan"} - assert flatten_say_tool_calls(msg) == msg - - -def test_flatten_drops_empty_or_non_say_tool_calls(): - """A non-``say`` call (or empty text) leaves content alone but still - strips the structured calls so the template renders no JSON block.""" - weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}} - out = flatten_say_tool_calls({"role": "assistant", "content": "plan only", "tool_calls": [weather]}) - assert out["content"] == "plan only" - assert "tool_calls" not in out - - -def test_flatten_marker_only_when_content_empty(): - msg = {"role": "assistant", "content": "", "tool_calls": [_say_call("hi")]} - out = flatten_say_tool_calls(msg) - assert out["content"] == "hi" diff --git a/tests/policies/smolvla/test_smolvla2_vqa_overlay.py b/tests/policies/smolvla/test_smolvla2_vqa_overlay.py deleted file mode 100644 index 987574f8f..000000000 --- a/tests/policies/smolvla/test_smolvla2_vqa_overlay.py +++ /dev/null @@ -1,228 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for the SmolVLA2 runtime's interactive-VQA helpers. - -Covers camera selection, VQA-answer parsing, and the bounding-box / -keypoint overlay drawing — the pure functions, no model load. -""" - -import numpy as np -import pytest - -from lerobot.policies.smolvla2.inference.vqa import ( - answer_has_overlay, - available_cameras, - camera_short_name, - draw_vqa_overlay, - observation_image_to_pil, - parse_vqa_answer, - prompt_camera_choice, -) - -PIL = pytest.importorskip("PIL") -from PIL import Image # noqa: E402 - -# --------------------------------------------------------------------------- -# Camera selection -# --------------------------------------------------------------------------- - - -def test_available_cameras_extracts_and_sorts_image_keys(): - observation = { - "observation.images.wrist": object(), - "observation.state": object(), - "observation.images.top": object(), - "task": "x", - } - assert available_cameras(observation) == [ - "observation.images.top", - "observation.images.wrist", - ] - - -def test_available_cameras_handles_none_and_empty(): - assert available_cameras(None) == [] - assert available_cameras({}) == [] - - -def test_camera_short_name_strips_prefix(): - assert camera_short_name("observation.images.top") == "top" - assert camera_short_name("top") == "top" - - -def test_prompt_camera_choice_single_camera_auto_selects(): - cams = ["observation.images.top"] - # input_fn must never be called for a single-camera setup. 
- chosen = prompt_camera_choice(cams, input_fn=_boom, print_fn=lambda *_: None) - assert chosen == "observation.images.top" - - -def test_prompt_camera_choice_by_number(): - cams = ["observation.images.top", "observation.images.wrist"] - chosen = prompt_camera_choice(cams, input_fn=lambda _: "2", print_fn=lambda *_: None) - assert chosen == "observation.images.wrist" - - -def test_prompt_camera_choice_by_name(): - cams = ["observation.images.top", "observation.images.wrist"] - chosen = prompt_camera_choice(cams, input_fn=lambda _: "top", print_fn=lambda *_: None) - assert chosen == "observation.images.top" - - -def test_prompt_camera_choice_invalid_returns_none(): - cams = ["observation.images.top", "observation.images.wrist"] - assert prompt_camera_choice(cams, input_fn=lambda _: "99", print_fn=lambda *_: None) is None - - -def _boom(*_args, **_kwargs): - raise AssertionError("input_fn should not be called") - - -# --------------------------------------------------------------------------- -# Answer parsing -# --------------------------------------------------------------------------- - - -def test_parse_bbox_answer(): - answer = '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}' - parsed = parse_vqa_answer(answer) - assert parsed["kind"] == "bbox" - assert answer_has_overlay(parsed) - - -def test_parse_keypoint_answer(): - answer = '{"label": "blue cube", "point_format": "xy", "point": [120, 90]}' - parsed = parse_vqa_answer(answer) - assert parsed["kind"] == "keypoint" - assert answer_has_overlay(parsed) - - -def test_parse_count_answer_is_not_an_overlay(): - parsed = parse_vqa_answer('{"label": "cubes", "count": 2}') - assert parsed["kind"] == "count" - assert not answer_has_overlay(parsed) - - -def test_parse_invalid_json_returns_none(): - assert parse_vqa_answer("not json at all") is None - assert parse_vqa_answer("") is None - # A JSON array is valid JSON but not a VQA answer object. 
- assert parse_vqa_answer("[1, 2, 3]") is None - - -def test_parse_unknown_shape(): - parsed = parse_vqa_answer('{"weird": "payload"}') - assert parsed["kind"] == "unknown" - assert not answer_has_overlay(parsed) - - -# --------------------------------------------------------------------------- -# Overlay drawing -# --------------------------------------------------------------------------- - - -def _blank(size=(160, 120)): - return Image.new("RGB", size, (0, 0, 0)) - - -def test_draw_bbox_overlay_changes_pixels_and_preserves_size(): - img = _blank() - parsed = parse_vqa_answer( - '{"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]}' - ) - out = draw_vqa_overlay(img, parsed) - assert out.size == img.size - assert out.tobytes() != img.tobytes() - - -def test_draw_keypoint_overlay_changes_pixels(): - img = _blank() - parsed = parse_vqa_answer('{"label": "cube", "point_format": "xy", "point": [80, 60]}') - out = draw_vqa_overlay(img, parsed) - assert out.size == img.size - assert out.tobytes() != img.tobytes() - - -def test_draw_overlay_non_spatial_leaves_image_unchanged(): - img = _blank() - parsed = parse_vqa_answer('{"label": "cubes", "count": 2}') - out = draw_vqa_overlay(img, parsed) - assert out.tobytes() == img.tobytes() - - -def test_draw_overlay_tolerates_malformed_coordinates(): - img = _blank() - # bbox with the wrong arity must not raise. - out = draw_vqa_overlay(img, {"kind": "bbox", "payload": {"detections": [{"bbox": [1, 2]}]}}) - assert out.size == img.size - - -def test_observation_image_to_pil_from_batched_float_array(): - # (1, C, H, W) float array in [0, 1], the runtime observation shape. 
- arr = np.zeros((1, 3, 24, 32), dtype=np.float32) - pil = observation_image_to_pil(arr) - assert pil.size == (32, 24) - assert pil.mode == "RGB" - - -# --------------------------------------------------------------------------- -# PaliGemma -format answers (PI052 trains spatial VQA in this vocab) -# --------------------------------------------------------------------------- - - -def test_parse_loc_keypoint_answer(): - # label — y=512/1023≈0.5, x=256/1023≈0.25 - parsed = parse_vqa_answer(" blue cube") - assert parsed["kind"] == "keypoint" - assert parsed["normalized"] is True - x, y = parsed["payload"]["point"] - assert 0.24 < x < 0.26 - assert 0.49 < y < 0.51 - assert parsed["payload"]["label"] == "blue cube" - assert answer_has_overlay(parsed) - - -def test_parse_loc_bbox_answer(): - # label - parsed = parse_vqa_answer(" yellow cube") - assert parsed["kind"] == "bbox" - assert parsed["normalized"] is True - det = parsed["payload"]["detections"][0] - x1, y1, x2, y2 = det["bbox"] - assert x1 < x2 and y1 < y2 - assert det["label"] == "yellow cube" - assert answer_has_overlay(parsed) - - -def test_parse_loc_multiple_boxes(): - answer = " cube ; box" - parsed = parse_vqa_answer(answer) - assert parsed["kind"] == "bbox" - assert len(parsed["payload"]["detections"]) == 2 - - -def test_parse_loc_takes_precedence_over_json(): - # An answer with tokens is parsed as loc even if JSON-ish. - assert parse_vqa_answer('{"x": }')["normalized"] is True - - -def test_draw_loc_overlay_denormalizes_to_pixels(): - img = _blank((200, 100)) - parsed = parse_vqa_answer(" cube") # ~centre - out = draw_vqa_overlay(img, parsed) - assert out.size == img.size - assert out.tobytes() != img.tobytes()