diff --git a/Makefile b/Makefile index e02f02403..d3987101f 100644 --- a/Makefile +++ b/Makefile @@ -178,3 +178,9 @@ test-smolvla-ete-eval: --env.episode_length=5 \ --eval.n_episodes=1 \ --eval.batch_size=1 + +# E2E annotation pipeline smoke test against a tiny in-memory fixture +# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM +# backend, so it does not require a real model checkpoint or GPU. +annotation-e2e: + uv run python -m tests.annotations.run_e2e_smoke diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0d4e36172..5d847a94d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -45,6 +45,8 @@ title: Language Columns and Recipes - local: tools title: Tools + - local: annotation_pipeline + title: Annotation Pipeline - local: video_encoding_parameters title: Video encoding parameters - local: streaming_video_encoding diff --git a/docs/source/annotation_pipeline.mdx b/docs/source/annotation_pipeline.mdx new file mode 100644 index 000000000..02658ec9a --- /dev/null +++ b/docs/source/annotation_pipeline.mdx @@ -0,0 +1,291 @@ +# Annotation Pipeline + +`lerobot-annotate` watches each episode's video with a vision-language +model (VLM) and writes natural-language annotations back into your +dataset. It fills the two language columns from the +[Language Columns and Recipes](./language_and_recipes) page — +`language_persistent` and `language_events` — straight into +`data/chunk-*/file-*.parquet`. + +In short: point it at a LeRobot dataset, and it adds subtasks, plans, +memory, interjections, speech, and visual Q&A that a policy can be +trained on. + +## How it fits together + +```text + your dataset lerobot-annotate + (LeRobot v3.1) + │ + ▼ + ┌─────────────────────────────────────────────────────┐ + │ read episodes │ + └──────────────────────────┬──────────────────────────┘ + │ + ┌────────────────────┼────────────────────┐ + ▼ ▼ ▼ + ┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL + │ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI + └────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three + └────────────────────┼─────────────────────┘ + │ each module stages raw JSONL + ▼ into .annotate_staging/ + ┌─────────────────┐ + │ validator │ ◀── checks everything + └────────┬────────┘ + ▼ + ┌─────────────────┐ + │ writer │ + └────────┬────────┘ + ▼ + data/chunk-*/file-*.parquet + (+ meta/info.json tools) +``` + +Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared +VLM. Each module stages its output to disk, a validator checks it, and a +single writer rewrites the dataset shards in place. + +## What the pipeline produces + +Each module emits a few kinds of annotation ("styles"), routed to one of +the two language columns: + +| Style / atom | Column | Module | +| ------------------------------------------- | --------------------- | --------------- | +| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` | +| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` | +| `memory` (MEM-style compression) | `language_persistent` | `plan` | +| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` | +| `interjection` | `language_events` | `interjections` | +| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` | +| `vqa` (user / assistant pair) | `language_events` | `vqa` | + +### How subtasks are generated + +The `plan` module doesn't ask the VLM for subtasks in one shot. Instead +it uses a two-step **describe → segment** flow: + +1. **Describe** — the VLM narrates only what it actually sees in the + chosen camera (no guessing about the task). +2. **Segment** — that description is fed back in, and the VLM splits the + episode into consecutive atomic subtasks. + +Both passes see the episode as **timestamped contact sheets** — frames +sampled at `frames_per_second` (0.5s by default) and packed into JPEG +grids with each frame's time burned into its corner, so the VLM cites +exact boundary times directly. This is far cheaper in vision tokens than +one image per frame, so the sampling can stay dense; episodes longer than +`max_frames_per_prompt` are split into windows at the same density and +merged. Both prompts also carry a causal **event-boundary** definition (a +new event starts when an object becomes held / is released / reaches a new +location / a lid changes state / contents move) to sharpen where cuts land. + +The resulting spans are then stitched into a gap-free, full-episode +cover, so **every frame has exactly one active subtask**. See +[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py) +for the production settings (single camera, timestamped contact sheets, +auto-windowed subtask generation). + +### Tools + +The writer does **not** add a `tools` column to the parquet. The tool +catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)). +After every run, the pipeline makes sure the canonical `say` schema is in +that list, keeping any tools you declared beforehand. + +Want to add your own tool? Edit `meta/info.json["tools"]` directly — the +pipeline preserves whatever is already there. That makes the tool visible +to the chat template, so the model can learn to _generate_ the call. The +runtime layer that actually _executes_ a generated call (the `Tool` +protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of +this PR — the [Tools](./tools) doc marks those pieces as +not-yet-implemented. + +## Running on Hugging Face Jobs + +Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs). +The repo ships a launcher script you copy and tweak for your dataset: + +```bash +HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py +``` + +[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py) +starts a single-GPU `h200` job (bump it to `h200x4` for big datasets) +that: + +1. installs `lerobot` (from `main`) plus the annotation extras, +2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and + drives it over the OpenAI-compatible API, +3. runs the `plan` / `interjections` / `vqa` modules across the dataset + with `lerobot-annotate`, +4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or + back to `--repo_id` in place if you leave that unset). + +To use a different dataset, model, or hub repo, edit the `CMD` block in +the script. Every flag there maps directly to a `lerobot-annotate` flag +(run `lerobot-annotate --help` for the full list). + +## Key options + +These are the flags you'll reach for most often. Run +`lerobot-annotate --help` for everything else; the defaults are tuned for +short manipulation episodes. + +### Dataset in / out + +| Flag | Default | What it does | +| ----------------- | ------- | ----------------------------------------------------------------------- | +| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). | +| `--root` | — | Annotate a local dataset directory instead. | +| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). | +| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). | +| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). | +| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. | + +### Which modules run + +Every module is on by default and can be toggled independently (set to +`false` to skip it, e.g. to iterate on one module at a time): + +| Flag | Default | Turns off | +| ------------------------- | ------- | ----------------------------------- | +| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug | +| `--interjections.enabled` | `true` | interjections + speech atoms | +| `--vqa.enabled` | `true` | the VQA pairs | + +### The VLM (`--vlm.*`) + +| Flag | Default | What it does | +| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- | +| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. | +| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. | +| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). | +| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). | +| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). | +| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. | +| `--vlm.max_new_tokens` | `512` | Generation cap per call. | +| `--vlm.temperature` | `0.2` | Sampling temperature. | + +### Subtasks / plan / memory (`--plan.*`) + +| Flag | Default | What it does | +| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- | +| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). | +| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. | +| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). | +| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. | +| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). | +| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). | +| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. | +| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). | +| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). | + +### Interjections + VQA + +| Flag | Default | What it does | +| ----------------------------------------------- | ------- | ---------------------------------------------------------- | +| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. | +| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. | +| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). | +| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. | + +## Contributing new modules + +The pipeline is built to grow, and **contributions are very welcome** — +a brand-new module (say, trajectory traces or affordances), a new prompt +template, a smarter grounding flow, or quality fixes to the existing +`plan` / `interjections` / `vqa` modules. + +Every module lives under +`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM +client and the keyframe cache, writes its raw output to the staging +tree, and plugs into the executor as its own phase. Got an idea? Open an +issue or PR on [the repo](https://github.com/huggingface/lerobot). + +## How recipes consume the output + +The annotations are meant to be read by recipes (see +[Language Columns and Recipes](./language_and_recipes)). Typically: + +- low-level / high-level / memory-update branches read + `subtask` / `plan` / `memory` from `language_persistent`. +- an interjection-response branch reads `interjection` events plus the + paired speech atom (merged into one assistant turn via `tool_calls_from`) + and the matching `plan` refresh at the same timestamp. +- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from + `language_events`. + +## Why state and events are split + +Two ideas shape the design: + +1. **Persistent state vs. exact events.** Persistent rows (`subtask`, + `plan`, `memory`) apply to the whole episode and answer "what's true + right now?". Event rows (`interjection`, `vqa`, speech) appear only on + the one frame whose timestamp matches. Timestamps are copied straight + from the source parquet — never recomputed in floating point. +2. **One VLM pass.** All three modules share a single VLM client (the + OpenAI-compatible client talking to the job's vLLM server), so you pay + for one model load per dataset, not three. + +## Re-running a single module + +Each module stages its raw output to +`/.annotate_staging/episode_{N:06d}/.jsonl`. This makes +prompt iteration cheap: re-running one module overwrites only its own +JSONL, then the writer recomposes the final parquet. Disable modules you +don't want with `--plan.enabled=false` (and likewise +`--interjections.enabled` / `--vqa.enabled`) to test one at a time. + +## What the validator checks + +Before the writer runs, `StagingValidator` confirms: + +- every event row lands exactly on a real frame timestamp; +- no speech / interjection pairs are left orphaned; +- `plan` is refreshed at every interjection timestamp; +- `memory` rows fall on subtask boundaries (a warning, not an error); +- each VQA assistant `content` is valid JSON in one of the + bbox / keypoint / count / attribute / spatial shapes; +- every row goes to the column chosen by `column_for_style(style)`. + +Any error aborts the writer. Pass `--skip_validation=true` to override +while debugging. + +## Where each module's ideas come from + +- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417)) + for atom granularity ("pick up one piece of lettuce", "place bowl to + box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) + for "how, not what" detail. +- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)): + keep only the minimal relevant information — preserve outcomes, drop + specific attributes. +- **`interjections`.** Hi Robot's scenario taxonomy: negative task, + situated correction, specific constraint, preference. Speech is a + tool-call-only atom + (`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`). +- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for + grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`, + keypoints) and Steerable VLA Policies + ([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction + grounding. Pi0.7 also grounds answers across abstraction levels. + +When improving a module, tweak its prompt template in +`src/lerobot/annotations/steerable_pipeline/prompts/` rather than +rewriting from scratch. + +## Roughly how much it costs + +Per episode, the pipeline makes about `max_steps` plan calls, +`max_interjections_per_episode` interjection calls, and +`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8 +subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's +~50 VLM calls. + +Storage stays small: `language_persistent` is at most tens of KB per +episode (parquet dictionary-encodes the one entry that repeats across +frames), and `language_events` is empty on most frames — its size scales +with the number of emissions, not `num_frames × num_emissions`. diff --git a/examples/annotations/run_hf_job.py b/examples/annotations/run_hf_job.py new file mode 100644 index 000000000..a77e22f14 --- /dev/null +++ b/examples/annotations/run_hf_job.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM). + +Spawns one single-GPU ``h200`` job that: + + 1. installs ``lerobot`` from ``main`` plus the annotation extras, + 2. boots one vllm server with Qwen3.6-27B (dense VLM), + 3. runs the plan / interjections / vqa modules across the dataset + in free-form mode (each episode generates its own subtasks + + memory), + 4. uploads the annotated dataset to ``--new_repo_id`` (when set) + or back to ``--repo_id``. + +Usage: + + HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py + +Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your +run. For larger datasets, scale to ``h200x4`` and raise +``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match. +""" + +import os + +from huggingface_hub import get_token, run_job + +token = os.environ.get("HF_TOKEN") or get_token() +if not token: + raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`") + +CMD = ( + "apt-get update -qq && apt-get install -y -qq git ffmpeg && " + "pip install --no-deps " + "'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && " + "pip install --upgrade-strategy only-if-needed " + "datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect " + "openai && " + "export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && " + "export VLLM_VIDEO_BACKEND=pyav && " + "lerobot-annotate " + "--repo_id=pepijn223/robocasa_pretrain_human300_v4 " + "--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated " + "--push_to_hub=true " + "--vlm.backend=openai " + "--vlm.model_id=Qwen/Qwen3.6-27B " + "--vlm.num_gpus=1 " + '--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B ' + "--tensor-parallel-size 1 --max-model-len 32768 " + '--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" ' + "--vlm.serve_ready_timeout_s=1800 " + # Qwen3.6 ships with thinking on; annotation wants plain JSON answers. + "--vlm.chat_template_kwargs='{\"enable_thinking\": false}'" +) + +job = run_job( + image="vllm/vllm-openai:latest", + command=["bash", "-c", CMD], + flavor="h200", + secrets={"HF_TOKEN": token}, + timeout="2h", +) +print(f"Job URL: {job.url}") +print(f"Job ID: {job.id}") diff --git a/pyproject.toml b/pyproject.toml index e43f8ef81..0dc86d7ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -229,6 +229,21 @@ vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] +# Annotation pipeline (lerobot-annotate). The only backend is ``openai``, +# which talks to any OpenAI-compatible server (``vllm serve`` / +# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs +# (see examples/annotations/run_hf_job.py). +annotations = [ + "lerobot[dataset]", + "lerobot[transformers-dep]", + "openai>=1.40,<2.0", + # ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and + # uv's single unified lock would then cap ``torch`` for every extra + # (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI + # break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM; + # install it locally only if you run your own ``vllm serve``. +] + # Development dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"] notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"] @@ -323,6 +338,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" +lerobot-annotate="lerobot.scripts.lerobot_annotate:main" lerobot-rollout="lerobot.scripts.lerobot_rollout:main" # ---------------- Tool Configurations ---------------- @@ -341,7 +357,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] [tool.setuptools.package-data] -lerobot = ["envs/*.json"] +lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"] [tool.setuptools.packages.find] where = ["src"] diff --git a/src/lerobot/annotations/__init__.py b/src/lerobot/annotations/__init__.py new file mode 100644 index 000000000..67782f192 --- /dev/null +++ b/src/lerobot/annotations/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/annotations/steerable_pipeline/__init__.py b/src/lerobot/annotations/steerable_pipeline/__init__.py new file mode 100644 index 000000000..a8da5e05e --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/__init__.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Steerable annotation pipeline producing ``language_persistent`` and +``language_events`` columns for LeRobot datasets. + +The pipeline is decomposed into three independently runnable modules whose +outputs are staged per-episode before a final parquet rewrite: + +- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles +- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech +- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs +""" + +from .config import AnnotationPipelineConfig +from .validator import StagingValidator, ValidationReport +from .writer import LanguageColumnsWriter + +__all__ = [ + "AnnotationPipelineConfig", + "LanguageColumnsWriter", + "StagingValidator", + "ValidationReport", +] diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py new file mode 100644 index 000000000..86d6cadd9 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/config.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class PlanConfig: + """``plan`` module: subtasks + plan + memory + task augmentation.""" + + enabled: bool = True + + # ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables. + n_task_rephrasings: int = 10 + + # Derive the task from video instead of episode_task: off / if_short / always. + # Affects prompts only; ``meta/tasks.parquet`` is untouched. + derive_task_from_video: str = "if_short" + derive_task_min_words: int = 3 + + # --- Frame input: timestamped contact sheets (always on) --------------- + # The subtask describe/segment passes ALWAYS render the episode as + # macrodata/refiner-style contact sheets: sampled frames packed into JPEG + # grids with each frame's timestamp burned into its corner, so the VLM + # cites the exact source time of a boundary directly. This is far cheaper + # in vision tokens than one image per frame (≈2× faster subtask generation + # in practice), which is why the sampling is dense by default. + # + # ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s. + frames_per_second: float = 2.0 + # Frame budget per VLM call (= columns × rows × sheets). When a whole + # episode sampled at ``frames_per_second`` exceeds this, the episode is + # AUTOMATICALLY split into consecutive windows of + # ``max_frames_per_prompt`` frames each (one describe→segment call per + # window, still at the full ``frames_per_second`` density), and the + # per-window spans are merged + stitched into one contiguous cover. So an + # episode of any length is always covered at the full sampling density. + max_frames_per_prompt: int = 60 + contact_sheet_columns: int = 5 + contact_sheet_frames_per_sheet: int = 20 + contact_sheet_frame_width: int = 224 + contact_sheet_quality: int = 84 + + min_subtask_seconds: float = 1.5 + plan_max_steps: int = 8 + + # Narrate-only grounding pass before segmenting — best defense against subtasks + # invented from the task text (+1 VLM call/episode). + subtask_describe_first: bool = True + + # Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only. + emit_plan: bool = True + + # Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only. + # Symmetric counterpart of ``emit_plan``. + emit_memory: bool = True + + # (subtask spans are always stitched to a contiguous full-episode cover; not configurable.) + + # Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings. + task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig()) + + +@dataclass +class TaskAugAxesConfig: + """5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm / + omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings + when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to + omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic.""" + + enabled: bool = False + + synonym_paraphrase: int = 3 + omit_arm: int = 3 + omit_orientation: int = 2 + omit_grasp_method: int = 2 + combined_omissions: int = 2 + + +@dataclass +class InterjectionsConfig: + """``interjections`` module: interjections + paired speech.""" + + enabled: bool = True + + # Each emits a paired (interjection, speech) row + a plan refresh at that ts. + max_interjections_per_episode: int = 3 + interjection_min_t: float = 2.0 + + # Frame window centered on the timestamp so the VLM sees motion, not one frame. + interjection_window_seconds: float = 2.0 + interjection_window_frames: int = 4 + + +@dataclass +class VqaConfig: + """``vqa`` module: general VQA.""" + + enabled: bool = True + vqa_emission_hz: float = 1.0 + K: int = 1 + """Consecutive frames per emission tick. The VLM grounds on the FIRST frame, + so K>1 smears stale labels onto moved frames. Default 1 (no smear).""" + question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial") + + # True: ground VQA only on --vlm.camera_key (default: every camera). + restrict_to_default_camera: bool = False + + +@dataclass +class VlmConfig: + """Shared Qwen-VL client configuration.""" + + # Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when + # auto_serve=True); ``stub`` is for tests. + backend: str = "openai" + model_id: str = "Qwen/Qwen3.6-27B" + + # OpenAI-compatible endpoint; ``EMPTY`` key works for local servers. + api_base: str = "http://localhost:8000/v1" + api_key: str = "EMPTY" + + # Spawn a server if none answers api_base; False = fail fast on a remote. + auto_serve: bool = True + serve_port: int = 8000 + # Override the auto-serve command; ``{port}`` substituted per replica. + serve_command: str | None = None + + # Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each. + parallel_servers: int = 1 + num_gpus: int = 0 + client_concurrency: int = 16 + serve_ready_timeout_s: float = 600.0 + + max_new_tokens: int = 512 + temperature: float = 0.2 + + # Auto-serve context length (None → 32768); other vLLM flags go in serve_command. + max_model_len: int | None = None + + # Camera for keyframes; None → first ``observation.images.*`` key. + camera_key: str | None = None + # Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}). + chat_template_kwargs: dict[str, Any] | None = None + + +@dataclass +class ExecutorConfig: + """Executor settings (intra-process episode concurrency; distribution via HF Jobs).""" + + # Episodes processed concurrently per phase; main knob for saturating the servers. + episode_parallelism: int = 16 + + +@dataclass +class AnnotationPipelineConfig: + """Top-level config for ``lerobot-annotate`` (rewrites data shards in place).""" + + # Hub dataset: download source when ``root`` unset; push target when push_to_hub + # is on and ``new_repo_id`` unset. + repo_id: str | None = None + + # Separate push target (matches the LeRobot edit tools). Unset → push in place. + new_repo_id: str | None = None + + root: Path | None = None + + # Defaults to ``/.annotate_staging/``. + staging_dir: Path | None = None + + seed: int = 1729 + + plan: PlanConfig = field(default_factory=PlanConfig) + interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig) + vqa: VqaConfig = field(default_factory=VqaConfig) + + vlm: VlmConfig = field(default_factory=VlmConfig) + executor: ExecutorConfig = field(default_factory=ExecutorConfig) + + skip_validation: bool = False + only_episodes: tuple[int, ...] | None = None + + # Keyframe decode backend forwarded to ``decode_video_frames``. None → + # library default (torchcodec when available, else PyAV). Or pin + # ``"torchcodec"`` / ``"pyav"`` explicitly. + video_backend: str | None = None + + # Upload to the Hub (new_repo_id if set, else repo_id; one must be set). + push_to_hub: bool = False + push_private: bool = False + push_commit_message: str | None = None + + def resolved_staging_dir(self, root: Path) -> Path: + return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging" diff --git a/src/lerobot/annotations/steerable_pipeline/executor.py b/src/lerobot/annotations/steerable_pipeline/executor.py new file mode 100644 index 000000000..69d10bc89 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/executor.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""In-process executor that runs the annotation phases. + +The executor runs **six phases** in dependency order: + + phase 1: ``plan`` module (plan + subtasks + memory) + phase 2: ``interjections`` module (interjections + speech) + phase 3: ``plan`` plan-update pass — re-runs plan emission at every + interjection timestamp produced by phase 2 + phase 4: ``vqa`` module (VQA) + phase 5: validator + phase 6: writer + +Phase 3 is why the ``plan`` module must be re-entered after the +``interjections`` module — to refresh ``plan`` rows at interjection +timestamps. + +Distributed execution is provided by Hugging Face Jobs (see +``examples/annotations/run_hf_job.py``); the runner inside the job +invokes ``lerobot-annotate`` which uses this in-process executor. +Episode-level concurrency is controlled by +``ExecutorConfig.episode_parallelism``. +""" + +from __future__ import annotations + +import logging +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .config import AnnotationPipelineConfig +from .reader import EpisodeRecord, iter_episodes +from .staging import EpisodeStaging +from .validator import StagingValidator +from .writer import LanguageColumnsWriter + +logger = logging.getLogger(__name__) + + +@dataclass +class PhaseResult: + """Summary of one pipeline phase across all episodes.""" + + name: str + episodes_processed: int + episodes_skipped: int + + +@dataclass +class PipelineRunSummary: + """Aggregated result returned by :meth:`Executor.run`.""" + + phases: list[PhaseResult] + written_paths: list[Path] + validation_report: Any # ValidationReport, kept Any to avoid import cycle + + +@dataclass +class Executor: + """Run all six phases over a dataset root in-process. + + Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism`` + (a thread pool); cluster-level concurrency comes from running this + executor inside a Hugging Face Job. Tests construct the executor + directly with stub modules. + """ + + config: AnnotationPipelineConfig + plan: Any # PlanSubtasksMemoryModule + interjections: Any # InterjectionsAndSpeechModule + vqa: Any # GeneralVqaModule + writer: LanguageColumnsWriter + validator: StagingValidator + + def run(self, root: Path) -> PipelineRunSummary: + records = list(iter_episodes(root, only_episodes=self.config.only_episodes)) + n = len(records) + if n == 0: + raise ValueError(f"No episodes found under {root}/data/") + + print(f"[annotate] {n} episodes total", flush=True) + + staging_dir = self.config.resolved_staging_dir(root) + staging_dir.mkdir(parents=True, exist_ok=True) + + phases: list[PhaseResult] = [] + + # Phase 1: ``plan`` module (plan + subtasks + memory) + phases.append(self._run_module_phase("plan", records, staging_dir, self.plan)) + # Phase 2: ``interjections`` module (interjections + speech). It + # reads the ``plan`` module's subtask rows from the same staging + # tree to ground the interjection prompt in the correct local subtask. + phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections)) + # Phase 3: ``plan`` plan-update pass at interjection timestamps. + phases.append(self._run_plan_update_phase(records, staging_dir)) + # Phase 4: ``vqa`` module (VQA) + phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa)) + + print("[annotate] running validator...", flush=True) + report = self.validator.validate(records, staging_dir) + if not report.ok and not self.config.skip_validation: + raise RuntimeError(f"Staging validation failed: {report.summary()}") + print(f"[annotate] validator: {report.summary()}", flush=True) + + print(f"[annotate] writing parquet shards into {root}/data/...", flush=True) + written = self.writer.write_all(records, staging_dir, root) + print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True) + + # Keep meta/info.json aligned with the parquet schema we just wrote. + # Idempotent and additive: existing user metadata is preserved. + self._ensure_annotation_metadata_in_info(root) + + return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report) + + @staticmethod + def _ensure_annotation_metadata_in_info(root: Path) -> None: + """Write language features and canonical tools to ``meta/info.json``. + + ``LanguageColumnsWriter`` adds ``language_persistent`` and + ``language_events`` to parquet shards. The metadata must advertise + those columns too, otherwise non-streaming ``LeRobotDataset`` loads + cast against the old schema and fail on the extra parquet columns. + """ + from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415 + from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415 + + info_path = root / "meta" / "info.json" + if not info_path.exists(): + return + try: + info = load_info(root) + except Exception as exc: # noqa: BLE001 + print(f"[annotate] could not read {info_path}: {exc}", flush=True) + return + + changed = False + + merged_features = {**info.features, **language_feature_info()} + if merged_features != info.features: + info.features = merged_features + changed = True + + existing = info.tools or [] + names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)} + if SAY_TOOL_SCHEMA["function"]["name"] not in names: + info.tools = [*existing, SAY_TOOL_SCHEMA] + changed = True + + if changed: + write_info(info, root) + print( + "[annotate] meta/info.json: " + f"language_features={list(language_feature_info())}, " + f"tools={[t['function']['name'] for t in (info.tools or [])]}", + flush=True, + ) + + def _run_module_phase( + self, + name: str, + records: list[EpisodeRecord], + staging_dir: Path, + module: Any, + ) -> PhaseResult: + if not module.enabled: + print(f"[annotate] phase={name} skipped (module disabled)", flush=True) + return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records)) + n = len(records) + parallelism = max(1, min(self.config.executor.episode_parallelism, n)) + print( + f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})", + flush=True, + ) + t0 = time.time() + + def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]: + i, record = idx_record + ep_start = time.time() + staging = EpisodeStaging(staging_dir, record.episode_index) + module.run_episode(record, staging) + return i, record.episode_index, time.time() - ep_start + + processed = 0 + if parallelism == 1: + for i, record in enumerate(records, 1): + _, ep_idx, elapsed = _do((i, record)) + processed += 1 + print( + f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s", + flush=True, + ) + else: + with ThreadPoolExecutor(max_workers=parallelism) as pool: + futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)] + for fut in as_completed(futures): + i, ep_idx, elapsed = fut.result() + processed += 1 + print( + f"[annotate] {name} episode {processed}/{n} " + f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s", + flush=True, + ) + total = time.time() - t0 + print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True) + return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0) + + def _run_plan_update_phase( # noqa: PLR0915 + self, records: list[EpisodeRecord], staging_dir: Path + ) -> PhaseResult: + """Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced. + + The ``plan`` module owns the prompt; the ``interjections`` module + produced the timestamps. This phase therefore calls back into the + ``plan`` module with the interjection timestamps so its existing + prompt path is reused. + """ + if not self.plan.enabled or not self.interjections.enabled: + return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records)) + processed = 0 + for record in records: + staging = EpisodeStaging(staging_dir, record.episode_index) + interjection_rows = [ + row for row in staging.read("interjections") if row.get("style") == "interjection" + ] + interjection_times = [float(row["timestamp"]) for row in interjection_rows] + interjection_texts = [str(row.get("content") or "") for row in interjection_rows] + if interjection_times: + self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts) + processed += 1 + # Episodes without any interjections are skipped (no plan refresh + # needed); count them so the summary's processed+skipped == total. + return PhaseResult( + name="plan_update", + episodes_processed=processed, + episodes_skipped=len(records) - processed, + ) diff --git a/src/lerobot/annotations/steerable_pipeline/frames.py b/src/lerobot/annotations/steerable_pipeline/frames.py new file mode 100644 index 000000000..a6c904673 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/frames.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Keyframe extraction for the annotation pipeline. + +Modules attach decoded camera frames to their VLM prompts so the model can +ground subtask decomposition, interjection scenarios, and VQA in actual +visual content. The pipeline shares one provider across modules and one +episode at a time, with a small per-episode cache so multiple modules +querying the same timestamp pay decode cost once. +""" + +from __future__ import annotations + +import io +import logging +import math +import threading +from collections.abc import Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Protocol + +import PIL.Image +import torch + +from lerobot.configs.video import VideoEncoderConfig +from lerobot.datasets.video_utils import decode_video_frames, reencode_video + +from .reader import EpisodeRecord, snap_to_frame + +logger = logging.getLogger(__name__) + + +class FrameProvider(Protocol): + """Decodes camera frames at episode-relative timestamps.""" + + @property + def camera_keys(self) -> list[str]: + """All ``observation.images.*`` feature keys this provider can decode.""" + + def frames_at( + self, + record: EpisodeRecord, + timestamps: list[float], + camera_key: str | None = None, + ) -> list[Any]: + """Return one decoded frame per timestamp from ``camera_key`` (or default). + + Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape + :func:`lerobot.datasets.video_utils.decode_video_frames` returns. + :func:`to_image_blocks` converts them to PIL only at the VLM-message + boundary. + + Empty list if the camera is unavailable. ``camera_key=None`` falls back + to the provider's default camera so existing single-camera callers + (the ``plan`` and ``interjections`` modules) keep working unchanged. + """ + + def video_for_episode( + self, + record: EpisodeRecord, + max_frames: int, + camera_key: str | None = None, + ) -> list[Any]: + """Return up to ``max_frames`` decoded frames covering the whole episode. + + Sampling is uniform across the episode duration. Frames are + ``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps + them into one ``{"type":"video", "video":}`` block for a + Qwen-VL-compatible model that pools temporally itself. Empty list if + no camera available. + """ + + +@dataclass +class _NullProvider: + """No-op provider used when the dataset has no video keys or in tests.""" + + @property + def camera_keys(self) -> list[str]: + return [] + + def frames_at( + self, + record: EpisodeRecord, + timestamps: list[float], + camera_key: str | None = None, + ) -> list[Any]: + return [] + + def video_for_episode( + self, + record: EpisodeRecord, + max_frames: int, + camera_key: str | None = None, + ) -> list[Any]: + return [] + + +def null_provider() -> FrameProvider: + return _NullProvider() + + +@dataclass +class VideoFrameProvider: + """Decodes frames from the dataset's ``observation.images.*`` streams. + + By default the *first* camera key is used for the ``plan`` module + (subtask decomposition) and the ``interjections`` module (interjection + scenarios) — those prompts care about *what is happening*, not which + angle. The ``vqa`` module instead iterates over every camera in + :attr:`camera_keys` so each frame's + grounded answer (bbox/keypoint/...) is tagged with the camera it was + grounded against. + + ``camera_key`` overrides the default-camera choice but does not restrict + :attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` / + ``video_for_episode`` to read a non-default stream. + + Caches up to ``cache_size`` decoded frames per process to keep + co-timestamped ``interjections`` + ``plan`` plan-update calls cheap. + """ + + root: Path + camera_key: str | None = None + tolerance_s: float = 1e-2 + cache_size: int = 256 + # Keyframe decode backend forwarded to + # :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None`` + # uses the library default (torchcodec when available, else PyAV). + video_backend: str | None = None + _meta: Any = field(default=None, init=False, repr=False) + _cache: dict = field(default_factory=dict, init=False, repr=False) + _camera_keys: list[str] = field(default_factory=list, init=False, repr=False) + # Pipeline runs the three module phases under a ThreadPoolExecutor (see + # ``ExecutorConfig.episode_parallelism``); guard the dict cache and the + # one-shot warn flag against concurrent updates from worker threads. + _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) + # Serializes decode_video_frames calls: torchcodec hands out one + # ``VideoDecoder`` per file from a process-wide cache, and the decoder + # is not safe to drive from multiple threads at once. + _decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) + _warned_decode_fail: bool = field(default=False, init=False, repr=False) + + def __post_init__(self) -> None: + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415 + + self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root) + # Only ``video_keys`` are decodable here: the clip/decode paths read + # ``videos//from_timestamp`` from episode metadata, which exists + # only for video-stored cameras. Image-stored cameras (also in + # ``camera_keys``) would KeyError, so restrict the list — and the + # default — to video keys. + keys = list(self._meta.video_keys) + # Last-resort fallback: if metadata didn't surface any video keys but + # the caller explicitly named a camera (``--vlm.camera_key=...``), + # trust them — the key is by definition known to exist on the dataset. + if not keys and self.camera_key: + keys = [self.camera_key] + self._camera_keys = keys + if self.camera_key is None: + self.camera_key = keys[0] if keys else None + + @property + def camera_keys(self) -> list[str]: + """All ``observation.images.*`` keys available on this dataset.""" + return list(self._camera_keys) + + def frames_at( + self, + record: EpisodeRecord, + timestamps: list[float], + camera_key: str | None = None, + ) -> list[Any]: + target = camera_key if camera_key is not None else self.camera_key + if not timestamps or target is None: + return [] + # Snap each request to the nearest real frame timestamp: callers + # sample uniform grids whose points land mid-frame, and + # ``decode_video_frames`` rejects queries farther than + # ``tolerance_s`` from a decodable frame. Snapping also dedupes + # repeat queries through the cache. + if record.frame_timestamps: + timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps] + + out: list[Any] = [] + misses: list[float] = [] + miss_indices: list[int] = [] + with self._lock: + for i, ts in enumerate(timestamps): + key = (record.episode_index, target, round(float(ts), 6)) + cached = self._cache.get(key) + if cached is not None: + out.append(cached) + else: + out.append(None) + misses.append(float(ts)) + miss_indices.append(i) + + if misses: + decoded = self._decode(record.episode_index, misses, target) + # ``_decode`` returns exactly one frame per requested timestamp, + # or an empty list if decoding failed wholesale. A partial list + # would mean a frame/timestamp misalignment, so only pair them up + # when the counts match (``strict=True`` then guards regressions). + if len(decoded) == len(miss_indices): + with self._lock: + for i, frame in zip(miss_indices, decoded, strict=True): + out[i] = frame + key = (record.episode_index, target, round(float(timestamps[i]), 6)) + if len(self._cache) >= self.cache_size: + self._cache.pop(next(iter(self._cache))) + self._cache[key] = frame + # filter out any None left over from decode failures + return [frame for frame in out if frame is not None] + + def video_for_episode( + self, + record: EpisodeRecord, + max_frames: int, + camera_key: str | None = None, + ) -> list[Any]: + """Return up to ``max_frames`` frames uniformly sampled across the episode. + + The whole episode duration is covered; the model picks subtask + boundaries from the temporal pooling it does internally. Frames are + ``torch.Tensor`` (see :meth:`frames_at`). + """ + target = camera_key if camera_key is not None else self.camera_key + if max_frames <= 0 or target is None or not record.frame_timestamps: + return [] + n_frames = min(max_frames, len(record.frame_timestamps)) + if n_frames == len(record.frame_timestamps): + timestamps = list(record.frame_timestamps) + else: + t0 = record.frame_timestamps[0] + t_last = record.frame_timestamps[-1] + if t_last <= t0: + timestamps = [float(t0)] * n_frames + else: + step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0 + timestamps = [float(t0 + i * step) for i in range(n_frames)] + return self.frames_at(record, timestamps, camera_key=target) + + def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None: + """Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``. + + Returns ``None`` if the dataset has no video tracks or extraction + failed. Skips re-extract when the cached clip already exists. + Re-encodes to H.264 via + :func:`lerobot.datasets.video_utils.reencode_video` so the resulting + mp4 is decodable by every downstream video processor — stream-copy + would inherit the source codec (often AV1 in modern LeRobot + datasets), which vllm's libav build cannot decode. + """ + if self.camera_key is None: + return None + cache_dir.mkdir(parents=True, exist_ok=True) + out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4" + if out_path.exists() and out_path.stat().st_size > 0: + return out_path + ep = self._meta.episodes[record.episode_index] + from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"]) + to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"]) + src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key) + encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast") + try: + reencode_video( + src, + out_path, + camera_encoder=encoder, + overwrite=True, + start_time_s=from_timestamp, + end_time_s=to_timestamp, + ) + except Exception: + logger.warning( + "clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True + ) + return None + return out_path if out_path.exists() and out_path.stat().st_size > 0 else None + + def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]: + """Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors. + + Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames` + (torchcodec when available, PyAV otherwise; ``video_backend`` pins + one explicitly). Returns one frame per requested timestamp, or ``[]`` + if decoding failed — callers treat ``[]`` as "no frames available". + """ + ep = self._meta.episodes[episode_index] + from_timestamp = ep[f"videos/{camera_key}/from_timestamp"] + shifted = [from_timestamp + ts for ts in timestamps] + video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key) + + try: + # The module phases decode under a ThreadPoolExecutor (see + # ``ExecutorConfig.episode_parallelism``) but torchcodec's cached + # per-file decoder is single-threaded, so serialize decodes on a + # dedicated lock. Frame extraction is a small fraction of episode + # wall time (VLM calls dominate), so the contention is cheap. + with self._decode_lock: + # Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp. + decoded = decode_video_frames( + video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True + ) + return list(decoded) + except Exception as exc: + # Log loudly the first time so a silent vqa-module no-op (every + # prompt skipped because frames_at returned []) is debuggable from + # the job log instead of post-hoc parquet inspection. Subsequent + # failures stay quiet. + with self._lock: + already_warned = self._warned_decode_fail + if not already_warned: + self._warned_decode_fail = True + if not already_warned: + logger.warning( + "VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s", + episode_index, + camera_key, + video_path, + self.video_backend, + exc, + exc_info=exc, + ) + return [] + + +def make_frame_provider( + root: Path, camera_key: str | None = None, video_backend: str | None = None +) -> FrameProvider: + """Build a :class:`VideoFrameProvider` if videos are present, else null.""" + try: + provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend) + except Exception: + return null_provider() + if provider.camera_key is None: + return null_provider() + return provider + + +def _frame_to_pil(frame: Any) -> Any: + """Materialise a decoded frame as a ``PIL.Image`` for the VLM message. + + Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8, + straight from :func:`decode_video_frames`); PIL is only created here, at + the VLM-message boundary, because the chat backends expect PIL images / + data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched. + """ + if not isinstance(frame, torch.Tensor): + return frame + array = frame.detach().cpu() + if array.ndim == 3 and array.shape[0] in (1, 3): + array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C) + if array.shape[-1] == 1: + array = array.squeeze(-1) + return PIL.Image.fromarray(array.to(torch.uint8).numpy()) + + +def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]: + """Convert decoded frames to Qwen-VL-compatible image content blocks.""" + return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames] + + +def to_video_block(frames: list[Any]) -> list[dict[str, Any]]: + """Wrap a list of decoded frames as one Qwen-VL video block. + + Returns ``[]`` when the list is empty, so the caller can splat the result + into a content array without a separate emptiness check. + """ + if not frames: + return [] + return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}] + + +def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]: + """Wrap a video file URL as one ``video_url`` block. + + Used by the ``openai`` backend (transformers serve / vllm serve / + ktransformers serve), where the server handles frame sampling. + Returns ``[]`` when ``url`` is ``None`` so the caller can splat. + """ + if not url: + return [] + return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}] + + +def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image: + """Burn ``timestamp`` (seconds) into the top-left corner of ``image``. + + A solid black badge with white text, so a VLM reading a contact sheet can + cite the exact source time of each tile (e.g. ``012.50s``) directly, + instead of the caller having to map tile position back to time. Mirrors + the macrodata/refiner contact-sheet convention. + """ + from PIL import ImageDraw, ImageFont + + result = image.copy() + draw = ImageDraw.Draw(result) + font = ImageFont.load_default() + label = f"{timestamp:06.2f}s" + left, top, right, bottom = draw.textbbox((0, 0), label, font=font) + text_w, text_h = right - left, bottom - top + pad = max(3, round(min(image.width, image.height) * 0.018)) + draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0)) + draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font) + return result + + +def to_contact_sheet_blocks( + frames: Sequence[Any], + timestamps: Sequence[float], + *, + columns: int = 5, + frames_per_sheet: int = 20, + frame_width: int = 224, + quality: int = 84, +) -> list[dict[str, Any]]: + """Pack decoded frames into timestamped JPEG contact-sheet image blocks. + + Each frame is resized to ``frame_width`` wide, stamped with its + episode-relative timestamp, and tiled row-major into grids of + ``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}`` + block is returned per grid; many frames collapse into a few images, so a + long episode's temporal coverage stays dense at a fraction of the vision + tokens N separate frames would cost. ``frames`` and ``timestamps`` must be + aligned and equal length. Returns ``[]`` for empty input. + """ + from PIL import Image + + if not frames: + return [] + columns = max(1, columns) + frames_per_sheet = max(1, frames_per_sheet) + rows_per_sheet = math.ceil(frames_per_sheet / columns) + + tiles: list[PIL.Image.Image] = [] + for ts, frame in zip(timestamps, frames, strict=False): + img = _frame_to_pil(frame) + if not isinstance(img, PIL.Image.Image): + continue + img = img.convert("RGB") + if img.width != frame_width: + height = max(1, round(img.height * frame_width / img.width)) + img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR) + tiles.append(_draw_timestamp_badge(img, float(ts))) + if not tiles: + return [] + + blocks: list[dict[str, Any]] = [] + for start in range(0, len(tiles), frames_per_sheet): + chunk = tiles[start : start + frames_per_sheet] + cell_w = max(tile.width for tile in chunk) + cell_h = max(tile.height for tile in chunk) + sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0)) + for i, tile in enumerate(chunk): + x = (i % columns) * cell_w + y = (i // columns) * cell_h + sheet.paste(tile, (x, y)) + # JPEG round-trip at ``quality`` to match the refiner convention and + # shrink the wire payload; vision-token count is set by resolution, so + # the real saving is the grid packing, not the codec. + buf = io.BytesIO() + sheet.save(buf, format="JPEG", quality=quality) + buf.seek(0) + blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")}) + return blocks diff --git a/src/lerobot/annotations/steerable_pipeline/modules/__init__.py b/src/lerobot/annotations/steerable_pipeline/modules/__init__.py new file mode 100644 index 000000000..e9ff8ed23 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .general_vqa import GeneralVqaModule +from .interjections_and_speech import InterjectionsAndSpeechModule +from .plan_subtasks_memory import PlanSubtasksMemoryModule + +__all__ = [ + "GeneralVqaModule", + "InterjectionsAndSpeechModule", + "PlanSubtasksMemoryModule", +] diff --git a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py new file mode 100644 index 000000000..cdc87b579 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``vqa`` module: general VQA at a timed cadence. + +Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K`` +consecutive frames, and every anchored frame gets its own VQA pair. Each +pair is grounded on that single anchor frame — there is no per-pair frame +window. For datasets with multiple cameras, every anchored frame produces +one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is +generated against that camera's frame and stamped with the matching +``camera`` field on the emitted rows. The resolver disambiguates via +``camera=...``; recipes that consume VQA do so through one sub-recipe +per camera (see ``recipes/pi05_hirobot.yaml``). + +Within a single (frame, camera) we still emit at most one ``(vqa, user)`` +and one ``(vqa, assistant)`` row, so the resolver contract stays scalar. + +Question types covered (per the plan's ``vqa`` table): bbox, keypoint, +count, attribute, spatial. The assistant's ``content`` is a JSON string +whose schema depends on the question type. Malformed JSON triggers one +retry inside :meth:`VlmClient.generate_json`. +""" + +from __future__ import annotations + +import json +import logging +import random +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any + +from ..config import VqaConfig +from ..frames import FrameProvider, null_provider, to_image_blocks +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord +from ..staging import EpisodeStaging +from ..validator import classify_vqa_answer +from ..vlm_client import VlmClient + + +def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]: + """Return the relative frame indices to anchor VQA emissions to. + + For each emission tick (every ``1/hz`` seconds), we anchor ``k`` + consecutive frames starting at the tick. Ticks fall on the nearest + available source frame timestamp. + """ + if hz <= 0 or k <= 0 or not frame_timestamps: + return [] + t0 = frame_timestamps[0] + t_last = frame_timestamps[-1] + period = 1.0 / hz + indices: list[int] = [] + t = t0 + while t <= t_last + 1e-9: + # find the index of the nearest frame to t + nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t)) + for offset in range(k): + j = nearest_i + offset + if j >= len(frame_timestamps): + break + if not indices or indices[-1] != j: + indices.append(j) + t += period + # dedupe while preserving order + seen: set[int] = set() + deduped: list[int] = [] + for i in indices: + if i in seen: + continue + seen.add(i) + deduped.append(i) + return deduped + + +@dataclass +class GeneralVqaModule: + """Emit grounded VQA pairs at a timed cadence.""" + + vlm: VlmClient + config: VqaConfig + seed: int = 1729 + frame_provider: FrameProvider = field(default_factory=null_provider) + _warned_no_camera: bool = field(default=False, init=False, repr=False) + + @property + def enabled(self) -> bool: + return self.config.enabled + + def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None: + if not record.frame_timestamps: + staging.write("vqa", []) + return + rng = random.Random(f"{self.seed}:{record.episode_index}:vqa") + anchor_idx = _emission_anchor_indices( + record.frame_timestamps, self.config.vqa_emission_hz, self.config.K + ) + cameras = self._target_cameras() + if not cameras: + # No camera available — emit nothing rather than producing + # untagged rows that would fail validation. Surface a loud one- + # time warning so this is never silently a no-op. + if not self._warned_no_camera: + logging.getLogger(__name__).warning( + "vqa module found no cameras on the frame provider — " + "every episode will emit zero VQA rows. Check that the " + "dataset declares observation.images.* features in " + "meta/info.json; passing --vlm.camera_key= at the " + "CLI now also seeds the cameras list as a fallback." + ) + self._warned_no_camera = True + staging.write("vqa", []) + return + + # Build all messages first (one per (frame, camera)), then issue them + # as a single batched generate_json call so the client can fan them + # out concurrently. + per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = [] + for idx in anchor_idx: + ts = float(record.frame_timestamps[idx]) + qtype = rng.choice(self.config.question_types) + for camera in cameras: + messages = self._build_messages(record, qtype, ts, camera) + # Skip cameras that decoded to zero frames at this ts: no point + # asking the VLM to ground a bbox without an image. + if not _has_image_block(messages): + continue + per_call.append((ts, camera, qtype, messages)) + + if not per_call: + staging.write("vqa", []) + return + + results = self.vlm.generate_json([m for _, _, _, m in per_call]) + + rows: list[dict[str, Any]] = [] + for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True): + qa = self._postprocess(result) + if qa is None: + continue + question, answer = qa + rows.append( + { + "role": "user", + "content": question, + "style": "vqa", + "timestamp": ts, + "camera": camera, + "tool_calls": None, + } + ) + rows.append( + { + "role": "assistant", + "content": json.dumps(answer, sort_keys=True), + "style": "vqa", + "timestamp": ts, + "camera": camera, + "tool_calls": None, + } + ) + staging.write("vqa", rows) + + def _target_cameras(self) -> list[str]: + """Return the cameras the ``vqa`` module should iterate per anchored frame. + + Defaults to every camera the provider exposes. Datasets with no + cameras (or test/null providers) yield an empty list, which makes + ``run_episode`` a no-op. + + When ``config.restrict_to_default_camera`` is set, VQA grounds on + only the provider's default camera (the single ``--vlm.camera_key`` + stream), matching the plan / interjection modules so the whole + pipeline focuses on one view. + """ + all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or []) + if getattr(self.config, "restrict_to_default_camera", False): + default = getattr(self.frame_provider, "camera_key", None) + if default and default in all_cameras: + return [default] + # ``restrict_to_default_camera`` is set but the configured default + # isn't one the provider exposes. Returning it anyway would make + # ``_decode`` raise a KeyError deep in frame extraction, so warn and + # fall through to every available camera instead. + if default: + logging.getLogger(__name__).warning( + "restrict_to_default_camera is set but camera_key=%r is not in the " + "provider's cameras %s; grounding VQA on all available cameras instead.", + default, + all_cameras, + ) + return all_cameras + + def _build_messages( + self, + record: EpisodeRecord, + question_type: str, + frame_timestamp: float, + camera_key: str, + ) -> list[dict[str, Any]]: + prompt = load_prompt("vqa").format( + episode_task=record.episode_task, + question_type=question_type, + ) + images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key) + content = [*to_image_blocks(images), {"type": "text", "text": prompt}] + return [{"role": "user", "content": content}] + + def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None: + if not isinstance(result, dict): + return None + question = result.get("question") + answer = result.get("answer") + if not isinstance(question, str) or not question.strip(): + return None + if not isinstance(answer, dict): + return None + # The validator will enforce shape; here we just sanity-check that the + # answer matches *some* known shape so we can drop garbage early. + if classify_vqa_answer(answer) is None: + return None + return question.strip(), answer + + +def _has_image_block(messages: list[dict[str, Any]]) -> bool: + """Return True if any user content block is a populated image block.""" + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for block in content: + if isinstance(block, dict) and block.get("type") == "image": + return True + return False diff --git a/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py new file mode 100644 index 000000000..616f9ce1b --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms). + +Two sub-passes: + +1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the + canonical task). No interjection row — the canonical task is already the + user utterance from ``meta/tasks.parquet``. + +2. For mid-episode interruptions, emit a co-timestamped pair: + {role:user, style:interjection, content:} + speech atom (role:assistant, style:None, tool_calls=[say(...)]) + Both rows go in ``language_events`` at the same timestamp. + +The ``plan`` module's :meth:`run_plan_updates` reuses this module's +interjection timestamps to refresh the ``plan`` row at the same instant. +""" + +from __future__ import annotations + +import random +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any + +from ..config import InterjectionsConfig +from ..frames import FrameProvider, null_provider, to_image_blocks +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame +from ..staging import EpisodeStaging +from ..vlm_client import VlmClient +from ..writer import speech_atom + + +@dataclass +class InterjectionsAndSpeechModule: + """Generate task-start speech and mid-episode interjection/speech pairs.""" + + vlm: VlmClient + config: InterjectionsConfig + seed: int = 1729 + frame_provider: FrameProvider = field(default_factory=null_provider) + + @property + def enabled(self) -> bool: + return self.config.enabled + + def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None: + rows: list[dict[str, Any]] = [] + if record.frame_timestamps: + t0 = float(record.frame_timestamps[0]) + initial = self._initial_speech(record) + if initial: + rows.append(speech_atom(t0, initial)) + # Pull the ``plan`` module's subtask spans for this episode so the + # interjection prompt can ground itself in the actual current + # subtask at each chosen timestamp. The ``plan`` module ran first. + episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None + subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t) + rows.extend(self._mid_episode_interjections(record, subtask_spans)) + staging.write("interjections", rows) + + @staticmethod + def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None: + current: str | None = None + for span in spans: + if float(span["start"]) <= t: + current = span.get("text") + else: + break + return current + + def _initial_speech(self, record: EpisodeRecord) -> str | None: + prompt = load_prompt("interjections_initial_speech").format( + episode_task=record.episode_task, + ) + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + result = self.vlm.generate_json([messages])[0] + if isinstance(result, dict) and isinstance(result.get("text"), str): + text = result["text"].strip() + if text: + return text + return None + + def _mid_episode_interjections( + self, + record: EpisodeRecord, + subtask_spans: Sequence[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Generate interjections aligned with the actual demo trajectory. + + Teleop data is frozen — the robot already executed every step in + the video. A *counterfactual* interjection like "actually skip + the wipe" contradicts what then happens in the video, which is + what qwen36moe-10/11 surfaced as low-quality interjections. + + Instead, anchor every interjection at a subtask boundary and + write it as a natural user request for the *upcoming* subtask. + The robot's visible next behavior IS the interjection's effect, + so the training signal stays consistent: interjection text → + plan refresh → action stream all line up. + """ + if self.config.max_interjections_per_episode <= 0: + return [] + if len(subtask_spans) < 2: + # Need at least one transition (subtask 0 → subtask 1). + return [] + # Deterministic per-episode RNG so reruns are stable across SLURM jobs. + rng = random.Random(f"{self.seed}:{record.episode_index}:interjection") + + # Boundaries: the start time of every subtask except the first + # (which is just t0 and is covered by the initial-task speech atom). + boundaries: list[tuple[float, str, str]] = [] + for i in range(1, len(subtask_spans)): + ts = float(subtask_spans[i]["start"]) + if ts < self.config.interjection_min_t: + continue + prev_text = (subtask_spans[i - 1].get("text") or "").strip() + next_text = (subtask_spans[i].get("text") or "").strip() + if not next_text: + continue + boundaries.append((ts, prev_text, next_text)) + if not boundaries: + return [] + + n = min(self.config.max_interjections_per_episode, len(boundaries)) + chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0]) + + out: list[dict[str, Any]] = [] + for t, prev_subtask, next_subtask in chosen: + t_snap = snap_to_frame(t, record.frame_timestamps) + # Window straddles the boundary so the VLM sees the end of the + # previous subtask and the start of the next one — same + # conditioning the policy will see at training time. + window_ts = self._window_timestamps(t_snap, record.frame_timestamps) + prompt = load_prompt("interjections_interjection").format( + episode_task=record.episode_task, + prev_subtask=prev_subtask or "(starting from initial state)", + next_subtask=next_subtask, + timestamp=t_snap, + window_seconds=self.config.interjection_window_seconds, + ) + images = self.frame_provider.frames_at(record, window_ts) + content = [*to_image_blocks(images), {"type": "text", "text": prompt}] + messages = [{"role": "user", "content": content}] + result = self.vlm.generate_json([messages])[0] + if not isinstance(result, dict): + continue + interjection_text = result.get("interjection") + speech_text = result.get("speech") + if not isinstance(interjection_text, str) or not interjection_text.strip(): + continue + if not isinstance(speech_text, str) or not speech_text.strip(): + continue + out.append( + { + "role": "user", + "content": interjection_text.strip(), + "style": "interjection", + "timestamp": t_snap, + "tool_calls": None, + } + ) + out.append(speech_atom(t_snap, speech_text.strip())) + return out + + def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]: + """Return a small set of frame timestamps centered on ``t_anchor``. + + The window straddles the subtask boundary the interjection sits + on: roughly half the frames cover the end of the previous + subtask, half cover the start of the next one. The VLM therefore + sees BOTH what just finished AND what's about to start, which is + the conditioning we need to write a natural "now please do X" + request that matches the visible upcoming behavior. + """ + if not frame_timestamps: + return [t_anchor] + n = max(1, int(self.config.interjection_window_frames)) + if n == 1: + return [t_anchor] + window = float(self.config.interjection_window_seconds) + step = window / max(1, n - 1) + # Center the window on the anchor so half lands before, half after. + start_offset = -window / 2.0 + targets = [t_anchor + start_offset + step * i for i in range(n)] + first_ts = float(frame_timestamps[0]) + last_ts = float(frame_timestamps[-1]) + snapped: list[float] = [] + seen: set[float] = set() + for tgt in targets: + clamped = min(last_ts, max(first_ts, tgt)) + t = snap_to_frame(clamped, frame_timestamps) + if t not in seen: + seen.add(t) + snapped.append(t) + return snapped or [t_anchor] diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py new file mode 100644 index 000000000..b6df6551c --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -0,0 +1,780 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles).""" + +from __future__ import annotations + +import logging +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any + +from ..config import PlanConfig +from ..frames import ( + FrameProvider, + null_provider, + to_contact_sheet_blocks, +) +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame +from ..staging import EpisodeStaging +from ..vlm_client import VlmClient + +logger = logging.getLogger(__name__) + + +# Prepended to every describe / segment prompt so the VLM knows the images are +# timestamped contact-sheet grids, not a single video, and reads the burned-in +# per-tile timestamp when choosing boundaries. +def _contact_sheet_preamble(columns: int) -> str: + return ( + "CONTACT SHEETS — how to read the images below:\n" + f"- Each image is a grid of sampled video frames, {columns} per row, " + "with time running left-to-right then top-to-bottom (row-major).\n" + "- Each frame has its timestamp burned into the top-left corner, e.g. " + '"012.50s". Use that printed timestamp (not the tile position) when you ' + "choose start/end times; boundaries should land on or near a printed " + "timestamp.\n" + "- Frames continue across grids: an action may span the end of one sheet " + "and the start of the next, so do not place a boundary just because a new " + "image begins.\n\n" + ) + + +# Appended to every describe (and segment) prompt. A visual, causal definition +# of where one event ends and the next begins — adapted from macrodata/refiner — +# to sharpen cut points while the existing prompt keeps owning the imperative +# phrasing. +_CAUSAL_BOUNDARY_RULES = ( + "EVENT BOUNDARIES — where one event ends and the next begins:\n" + "- Start a new event whenever the world state changes: an object becomes " + "held (the gripper closes on it), an object is released (the gripper opens " + "and it stays put), an object reaches a new location, a lid/door/drawer " + "changes open/closed state, a tool starts or stops affecting a surface, or " + "contents visibly move (e.g. poured).\n" + "- If a single action changes the same state gradually and continuously, " + "keep it as ONE event — do not split it.\n" + "- If the same action repeats on different objects or target locations, " + "treat each repetition as a separate event.\n" + "- Do NOT create boundaries for idle time, camera motion, hesitation, or " + "tiny hand adjustments." +) + + +@dataclass +class PlanSubtasksMemoryModule: + """Generate subtask spans, plan, and memory rows. + + All output is persistent (lives in ``language_persistent``): + + - ``subtask`` rows: one per span, stamped at the span's *start* timestamp + (snapped to an exact frame). + - ``plan`` rows: emitted at ``t=0``; refreshed at every interjection + timestamp via :meth:`run_plan_updates` (called by the executor after + the ``interjections`` module completes). + - ``memory`` rows: emitted at each subtask boundary (= subtask start + timestamp from the second subtask onward). + """ + + vlm: VlmClient + config: PlanConfig + frame_provider: FrameProvider = field(default_factory=null_provider) + + @property + def enabled(self) -> bool: + return self.config.enabled + + def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None: + rows: list[dict[str, Any]] = [] + # Task driving every plan-module prompt: canonical episode_task, or a + # video-derived one when it's empty/placeholder (see derive_task_*). + effective_task = self._resolve_effective_task(record) + # task_aug rows at t=0: phrasings the renderer rotates ${task} through. + # Either the structured 5-axis taxonomy (task_aug_axes.enabled) or + # free-form n_task_rephrasings; the effective task is always emitted + # first so the rotation covers the source-of-truth phrasing. + t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0 + variants: list[str] | None = None + if self.config.task_aug_axes.enabled and effective_task: + variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes) + elif self.config.n_task_rephrasings > 0 and effective_task: + variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings) + if variants is not None: + rows.extend(self._task_aug_rows([effective_task, *variants], t0)) + + subtask_spans = self._generate_subtasks(record, task=effective_task) + + # subtask rows + for span in subtask_spans: + rows.append( + { + "role": "assistant", + "content": span["text"], + "style": "subtask", + "timestamp": snap_to_frame(span["start"], record.frame_timestamps), + "tool_calls": None, + } + ) + # Plan rows at every subtask boundary (incl. t=0). The plan is a + # numbered list of still-todo subtasks, so re-emitting at each + # boundary makes it shrink as work progresses — ${plan} at frame t is + # exactly what's left to do. + if self.config.emit_plan: + for span in subtask_spans: + boundary_t = snap_to_frame(span["start"], record.frame_timestamps) + plan_text = self._generate_plan( + record, subtask_spans, refresh_t=boundary_t, task=effective_task + ) + if plan_text is not None: + rows.append( + { + "role": "assistant", + "content": plan_text, + "style": "plan", + "timestamp": float(boundary_t), + "tool_calls": None, + } + ) + # memory rows at every subtask boundary except the very first start; + # skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only). + prior_memory = "" + memory_boundaries = enumerate(subtask_spans[1:], start=1) if self.config.emit_memory else [] + for i, span in memory_boundaries: + completed = subtask_spans[i - 1]["text"] + remaining = [s["text"] for s in subtask_spans[i:]] + mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task) + if mem_text: + ts = snap_to_frame(span["start"], record.frame_timestamps) + rows.append( + { + "role": "assistant", + "content": mem_text, + "style": "memory", + "timestamp": ts, + "tool_calls": None, + } + ) + prior_memory = mem_text + staging.write("plan", rows) + + # ------------------------------------------------------------------ + # Task derivation + rephrasings + # ------------------------------------------------------------------ + + _PLACEHOLDER_TASKS: frozenset[str] = frozenset( + { + "debug", + "test", + "tbd", + "todo", + "n/a", + "na", + "untitled", + "unnamed", + "default", + "placeholder", + } + ) + + def _resolve_effective_task(self, record: EpisodeRecord) -> str: + """Decide which task string drives the ``plan`` module for this episode. + + Returns the user-supplied ``record.episode_task`` unless + ``derive_task_from_video`` says otherwise (see config docstring). + Falls back gracefully to the canonical task if video derivation + fails. + """ + canonical = (record.episode_task or "").strip() + mode = (self.config.derive_task_from_video or "off").strip().lower() + if mode == "always": + derived = self._derive_task_from_video(record) + return derived or canonical + if mode == "if_short" and self._task_seems_bad(canonical): + derived = self._derive_task_from_video(record) + if derived: + return derived + return canonical + + def _task_seems_bad(self, task: str) -> bool: + if not task: + return True + if len(task.split()) < int(self.config.derive_task_min_words): + return True + return task.lower() in self._PLACEHOLDER_TASKS + + @staticmethod + def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]: + """Build deduplicated ``task_aug`` rows (role=user) at ``t0``.""" + seen: set[str] = set() + rows: list[dict[str, Any]] = [] + for phrasing in phrasings: + key = phrasing.strip() + if not key or key in seen: + continue + seen.add(key) + rows.append( + {"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None} + ) + return rows + + # ------------------------------------------------------------------ + # VLM call helpers — every plan-module prompt follows the same shape: + # build messages → single VLM call → pull a named field. + # ------------------------------------------------------------------ + + def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any: + """Run a single VLM call and return ``result[field]`` or ``None``. + + Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)`` + dance every prompt-call site needs. + """ + result = self.vlm.generate_json([messages])[0] + if isinstance(result, dict): + return result.get(field) + return None + + @staticmethod + def _text_message(text: str) -> list[dict[str, Any]]: + """One-shot text-only user message wrapped for ``generate_json``.""" + return [{"role": "user", "content": [{"type": "text", "text": text}]}] + + def _video_message( + self, + record: EpisodeRecord, + prompt: str, + window: tuple[float, float] | None = None, + ) -> list[dict[str, Any]]: + """User message combining the (optionally windowed) contact sheets with ``prompt``. + + The prompt is always prefixed with a short explanation of how to read + the timestamped grids, so the model treats them as one ordered + sequence of frames rather than unrelated images. + """ + prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt + content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}] + return [{"role": "user", "content": content}] + + def _derive_task_from_video(self, record: EpisodeRecord) -> str | None: + """Ask the VLM "what is this video about" with no task hint at all.""" + text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task") + return text.strip() if isinstance(text, str) and text.strip() else None + + def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]: + """Generate ``n`` text-only paraphrases of ``base_task``.""" + if n <= 0 or not base_task: + return [] + prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n) + raw = self._vlm_field(self._text_message(prompt), "rephrasings") + if not isinstance(raw, list): + return [] + out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)] + return [s for s in out if s][:n] + + # ------------------------------------------------------------------ + # Structured 5-axis task augmentation (EgoMimic-style taxonomy) + # ------------------------------------------------------------------ + + def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]: + """One VLM call → variants along the 5-axis taxonomy. + + Variants from all axes are flattened into a single list (the + downstream pipeline doesn't need to know about the per-axis + bucketing — every variant becomes a ``task_aug`` row). Order + is preserved for reproducibility: synonym_paraphrase first, + then omit_arm, then omit_orientation, then omit_grasp_method, + then combined_omissions. + """ + if not base_task: + return [] + prompt = load_prompt("plan_task_aug_axes").format( + base_task=base_task, + n_synonym=axes_cfg.synonym_paraphrase, + n_omit_arm=axes_cfg.omit_arm, + n_omit_orientation=axes_cfg.omit_orientation, + n_omit_grasp_method=axes_cfg.omit_grasp_method, + n_combined=axes_cfg.combined_omissions, + ) + result = self.vlm.generate_json([self._text_message(prompt)])[0] + if not isinstance(result, dict): + return [] + ordered_axes = ( + "synonym_paraphrase", + "omit_arm", + "omit_orientation", + "omit_grasp_method", + "combined_omissions", + ) + flat: list[str] = [] + seen: set[str] = set() + for axis in ordered_axes: + entries = result.get(axis) + if not isinstance(entries, list): + continue + for item in entries: + if not isinstance(item, str): + continue + key = item.strip().strip('"').strip("'") + if not key or key in seen: + continue + seen.add(key) + flat.append(key) + return flat + + def _episode_video_block( + self, record: EpisodeRecord, window: tuple[float, float] | None = None + ) -> list[dict[str, Any]]: + """Timestamped contact sheets for the describe / segmentation prompts. + + Always renders the (optionally windowed) episode as contact sheets: + frames sampled at ``frames_per_second`` and packed into timestamped + JPEG grids. ``max_frames_per_prompt`` caps the frame count; whole + episodes that exceed it are windowed upstream in + :meth:`_generate_subtasks` so each call stays within budget while the + full episode keeps its sampling density. + + When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE + (``ts - w0``) to match the window-relative time frame the + segmentation prompt works in (spans are offset back to absolute time + afterwards). + """ + if not record.frame_timestamps: + return [] + if window is not None: + w0, w1 = float(window[0]), float(window[1]) + dur = max(0.0, w1 - w0) + n = max(1, int(round(dur * self.config.frames_per_second)) + 1) + n = min(n, self.config.max_frames_per_prompt) + if n <= 1 or dur <= 0.0: + timestamps = [0.5 * (w0 + w1)] + else: + step = dur / (n - 1) + timestamps = [w0 + i * step for i in range(n)] + frames = self.frame_provider.frames_at(record, timestamps) + rel = [ts - w0 for ts in timestamps[: len(frames)]] + return self._contact_sheet_blocks(frames, rel) + episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0] + n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1) + n = min(n, self.config.max_frames_per_prompt) + timestamps = self._uniform_episode_timestamps(record, n) + frames = self.frame_provider.frames_at(record, timestamps) + return self._contact_sheet_blocks(frames, timestamps[: len(frames)]) + + @staticmethod + def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]: + """``n`` episode-relative timestamps spanning ``[t0, t_last]`` uniformly.""" + ts = record.frame_timestamps + if n >= len(ts): + return [float(t) for t in ts] + t0, t_last = float(ts[0]), float(ts[-1]) + if t_last <= t0 or n <= 1: + return [t0] * max(1, n) + step = (t_last - t0) / (n - 1) + return [t0 + i * step for i in range(n)] + + def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]: + """Build timestamped contact-sheet image blocks from decoded frames.""" + return to_contact_sheet_blocks( + frames, + timestamps, + columns=self.config.contact_sheet_columns, + frames_per_sheet=self.config.contact_sheet_frames_per_sheet, + frame_width=self.config.contact_sheet_frame_width, + quality=self.config.contact_sheet_quality, + ) + + def run_plan_updates( + self, + record: EpisodeRecord, + staging: EpisodeStaging, + interjection_times: Sequence[float], + interjection_texts: Sequence[str] | None = None, + ) -> None: + """Append additional ``plan`` rows at every interjection timestamp. + + Plans refresh ONLY on user interjections (event-driven). The + interjection text is forwarded into the prompt so the refreshed plan + reflects the user's correction. + """ + if not self.config.emit_plan: + return + existing = staging.read("plan") + # Pass the last frame timestamp so the final span is closed (else its + # end == start, zero duration, and a refresh inside it is missed). + episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None + spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t) + already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"} + new_rows = list(existing) + + texts: list[str | None] = ( + [None] * len(interjection_times) + if interjection_texts is None + else [str(t) if t else None for t in interjection_texts] + ) + for raw_t, inter_text in zip(interjection_times, texts, strict=True): + t = snap_to_frame(raw_t, record.frame_timestamps) + if t in already_planned: + continue + already_planned.add(t) + plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text) + if plan_text is not None: + new_rows.append( + { + "role": "assistant", + "content": plan_text, + "style": "plan", + "timestamp": t, + "tool_calls": None, + } + ) + staging.write("plan", new_rows) + + def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]: + """Generate subtask spans, optionally via a multi-call quality chain. + + Single call (default): watch video → emit subtask JSON. + + Multi-call (opt-in, higher quality, more VLM calls): + 1. ``subtask_describe_first`` — a grounding pass that narrates + ONLY what is visible (no JSON commitment to subtasks yet); + its description is injected into the segmentation prompt so + the model segments its own grounded observations instead of + pattern-matching the task text. + 2. segmentation — emit subtask JSON (as before). + """ + if record.row_count == 0 or not record.frame_timestamps: + return [] + episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0] + effective_task = task if task is not None else record.episode_task + + # ---- Auto-windowing (keeps the full sampling density) -------- + # Contact sheets are cheap, but a whole long episode sampled at + # ``frames_per_second`` can still exceed ``max_frames_per_prompt``. + # When it does, split into consecutive windows of exactly that many + # frames (one describe→segment call each, still at the full sampling + # density), then merge + stitch — so an episode of any length is + # covered at full density rather than subsampled into one sparse call. + fps = max(1e-6, float(self.config.frames_per_second)) + n_whole = int(round(episode_duration * fps)) + 1 + if n_whole > self.config.max_frames_per_prompt: + window_s = self.config.max_frames_per_prompt / fps + return self._generate_subtasks_windowed(record, effective_task, window_s) + + # ---- Pass 1 (optional): grounding description ---------------- + observation_block = "" + if getattr(self.config, "subtask_describe_first", False): + description = self._describe_episode(record, effective_task) + if description: + observation_block = ( + "You watched this video and described, chronologically, " + "ONLY what the robot actually does:\n" + f'"""{description}"""\n\n' + "Segment THAT grounded description (cross-checked against " + "the video) into atomic subtasks. Do not introduce any " + "action that is not in your description above.\n\n" + ) + + # ---- Pass 2: segmentation ------------------------------------ + prompt = self._with_causal_rules( + load_prompt("plan_subtasks").format( + episode_task=effective_task, + min_subtask_seconds=self.config.min_subtask_seconds, + max_steps=self.config.plan_max_steps, + episode_duration=f"{episode_duration:.3f}", + observation_block=observation_block, + ) + ) + spans = self._vlm_field(self._video_message(record, prompt), "subtasks") + cleaned = self._clean_spans(spans, record) + if not cleaned: + return [] + + # ---- Full-episode coverage stitch ---------------------------- + # The VLM can start after t0 or leave gaps, so frames fall through + # with no active subtask. Always stitch into a contiguous + # [t0, t_last] cover. + cleaned = self._stitch_full_coverage(cleaned, record) + + return cleaned + + def _generate_subtasks_windowed( + self, record: EpisodeRecord, task: str, window_s: float + ) -> list[dict[str, Any]]: + """Subtask generation in fixed-length windows at constant fps. + + Splits ``[t0, t_last]`` into consecutive windows of ``window_s`` + seconds, runs the describe -> segment chain on each window's own + frames (sampled at ``frames_per_second``), offsets + each window's spans back to absolute episode time, then merges + + stitches into a contiguous whole-episode cover. + """ + t0 = float(record.frame_timestamps[0]) + t_last = float(record.frame_timestamps[-1]) + all_spans: list[dict[str, Any]] = [] + w0 = t0 + n_windows = 0 + while w0 < t_last - 1e-6: + w1 = min(w0 + window_s, t_last) + all_spans.extend(self._subtasks_for_window(record, task, w0, w1)) + n_windows += 1 + w0 = w1 + logger.info( + "episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans", + record.episode_index, + n_windows, + window_s, + len(all_spans), + ) + # Merge across windows: clamp to the absolute episode, sort, and + # frame-snap to distinct starts (handles any boundary collisions). + cleaned = self._clean_spans(all_spans, record) + if not cleaned: + return [] + return self._stitch_full_coverage(cleaned, record) + + def _subtasks_for_window( + self, record: EpisodeRecord, task: str, w0: float, w1: float + ) -> list[dict[str, Any]]: + """Run describe -> segment on one ``[w0, w1]`` window. + + The model works in window-RELATIVE time ``[0, L]`` (it perceives + the window as a clip starting at 0); spans are offset back to + absolute ``[w0, w1]`` before returning. + """ + window = (w0, w1) + win_len = max(0.0, w1 - w0) + + observation_block = "" + if getattr(self.config, "subtask_describe_first", False): + description = self._describe_episode(record, task, window=window) + if description: + observation_block = ( + "You watched this video clip and described, chronologically, " + "ONLY what the robot actually does:\n" + f'"""{description}"""\n\n' + "Segment THAT grounded description (cross-checked against " + "the clip) into atomic subtasks. Do not introduce any " + "action that is not in your description above.\n\n" + ) + + prompt = self._with_causal_rules( + load_prompt("plan_subtasks").format( + episode_task=task, + min_subtask_seconds=self.config.min_subtask_seconds, + max_steps=self.config.plan_max_steps, + episode_duration=f"{win_len:.3f}", + observation_block=observation_block, + ) + ) + spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks") + # Window-relative clamp; no frame-snap dedupe yet (done on the + # merged absolute set). + cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False) + if not cleaned: + return [] + + # Offset window-relative spans back to absolute episode time. + for s in cleaned: + s["start"] = w0 + float(s["start"]) + s["end"] = w0 + float(s["end"]) + return cleaned + + def _stitch_full_coverage( + self, spans: list[dict[str, Any]], record: EpisodeRecord + ) -> list[dict[str, Any]]: + """Make subtask spans tile the full episode with no gaps. + + * The first subtask starts at the episode's first frame ``t0`` + (any idle / approach before the first labelled action is folded + into it), so every early frame has an active subtask. + * Each subtask's ``end`` is snapped to the next subtask's + ``start`` (gaps between spans are closed), and the final + subtask's ``end`` extends to the last frame ``t_last``. + + Starts are otherwise left as the (already frame-snapped, distinct) + values the VLM produced — only the FIRST start is pulled + back to ``t0``, which can't collide with a later span because it + was already the earliest. Purely deterministic; runs after the + VLM passes. + """ + if not spans or not record.frame_timestamps: + return spans + t0 = float(record.frame_timestamps[0]) + t_last = float(record.frame_timestamps[-1]) + spans = sorted(spans, key=lambda s: float(s["start"])) + spans[0]["start"] = t0 + for i in range(len(spans) - 1): + spans[i]["end"] = float(spans[i + 1]["start"]) + spans[-1]["end"] = t_last + for s in spans: + if float(s["end"]) < float(s["start"]): + s["end"] = float(s["start"]) + return spans + + @staticmethod + def _with_causal_rules(prompt: str) -> str: + """Append the causal event-boundary rules to a describe/segment prompt.""" + return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}" + + def _clean_spans( + self, + spans: Any, + record: EpisodeRecord, + bounds: tuple[float, float] | None = None, + dedupe: bool = True, + ) -> list[dict[str, Any]]: + """Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows. + + ``bounds`` overrides the clamp range — pass the window's + ``(w_lo, w_hi)`` when cleaning window-relative spans, or leave + ``None`` to clamp to the whole episode ``[t0, t_last]``. + ``dedupe`` runs the frame-snap distinct-start step; skip it for + window-relative spans (frame snapping is done once on the merged, + absolute-time set). + """ + if not spans: + return [] + if bounds is not None: + lo, hi = float(bounds[0]), float(bounds[1]) + else: + lo = record.frame_timestamps[0] + hi = record.frame_timestamps[-1] + cleaned: list[dict[str, Any]] = [] + for span in spans: + try: + start = float(span["start"]) + end = float(span["end"]) + text = str(span["text"]).strip() + except (KeyError, ValueError, TypeError): + continue + start = max(lo, min(start, hi)) + end = max(lo, min(end, hi)) + if end < start: + start, end = end, start + if not text: + continue + cleaned.append({"text": text, "start": start, "end": end}) + cleaned.sort(key=lambda s: s["start"]) + if dedupe: + return self._dedupe_starts_to_distinct_frames(cleaned, record) + return cleaned + + def _describe_episode( + self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None + ) -> str: + """Grounding pass: free-form chronological description of the (windowed) video.""" + prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task)) + text = self._vlm_field(self._video_message(record, prompt, window=window), "description") + return text.strip() if isinstance(text, str) and text.strip() else "" + + @staticmethod + def _dedupe_starts_to_distinct_frames( + spans: list[dict[str, Any]], record: EpisodeRecord + ) -> list[dict[str, Any]]: + """Bump same-frame subtask starts onto distinct frames. + + Two consecutive VLM spans whose ``start`` rounds to the same + source frame (after :func:`snap_to_frame`) would otherwise emit + two ``style=subtask`` rows at the identical persistent + timestamp. The training-time renderer's ``active_at(t, + style=subtask)`` resolver can't disambiguate that and raises + ``Ambiguous resolver for style='subtask'``. + + Walk the (sorted-by-start) spans, snap each to its frame, and + if the snapped frame is already taken push the span onto the + next unused frame so both subtasks survive on distinct + timestamps. If the episode ends before a free frame is found, + the trailing span is dropped with a warning — better than + poisoning the render. + """ + if not spans: + return spans + frames = record.frame_timestamps + if not frames: + return spans + used: set[float] = set() + out: list[dict[str, Any]] = [] + for span in spans: + ts = snap_to_frame(span["start"], frames) + if ts in used: + next_ts = next((f for f in frames if f > ts and f not in used), None) + if next_ts is None: + logger.warning( + "episode %d: subtask %r snapped to occupied frame " + "%.3f and no free later frame exists — dropping", + record.episode_index, + span.get("text"), + ts, + ) + continue + ts = next_ts + used.add(ts) + new_span = {**span, "start": ts} + if float(new_span.get("end", ts)) < ts: + new_span["end"] = ts + out.append(new_span) + return out + + def _generate_plan( + self, + record: EpisodeRecord, # noqa: ARG002 (kept for signature stability) + subtask_spans: Sequence[dict[str, Any]], + *, + refresh_t: float | None = None, + interjection: str | None = None, # noqa: ARG002 + task: str | None = None, # noqa: ARG002 + ) -> str | None: + """Deterministic plan = numbered list of *still-todo* subtasks. + + No VLM call: a plain numbered list keeps the plan aligned with the + upcoming subtasks (the old VLM "compact hierarchical plan" prompt + cost a round-trip per episode/refresh and could diverge). + + 1. + 2. + + On a refresh at ``refresh_t`` (from ``run_plan_updates`` on + interjections, and ``run_episode`` at each boundary), only subtasks + starting at or after ``refresh_t`` are included — so it always + describes what's left. + """ + if not subtask_spans: + return None + remaining = [ + s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t) + ] + if not remaining: + # Past the last subtask boundary on a late refresh — nothing + # left to plan; emit None so the caller skips the row. + return None + return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1)) + + def _generate_memory( + self, + record: EpisodeRecord, + prior_memory: str, + completed: str, + remaining: Sequence[str], + *, + task: str | None = None, + ) -> str: + prompt = load_prompt("plan_memory").format( + episode_task=(task if task is not None else record.episode_task), + prior_memory=prior_memory or "(none)", + completed_subtask=completed, + remaining_subtasks=", ".join(remaining) if remaining else "(none)", + ) + memory = self._vlm_field(self._text_message(prompt), "memory") + return memory.strip() if isinstance(memory, str) else "" diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/__init__.py b/src/lerobot/annotations/steerable_pipeline/prompts/__init__.py new file mode 100644 index 000000000..5ce8e163b --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prompt templates loaded as plain text. + +One file per use site. Templates use ``str.format(**vars)`` substitution; we +intentionally avoid jinja2 here so the templates remain inspectable in +plain editors and roundtrip cleanly through ``ruff format``. +""" + +from __future__ import annotations + +from pathlib import Path + +_DIR = Path(__file__).parent + + +def load(name: str) -> str: + """Read prompt template ``name.txt`` from the ``prompts/`` directory.""" + path = _DIR / f"{name}.txt" + return path.read_text(encoding="utf-8") diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/interjections_initial_speech.txt b/src/lerobot/annotations/steerable_pipeline/prompts/interjections_initial_speech.txt new file mode 100644 index 000000000..625ce920c --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/interjections_initial_speech.txt @@ -0,0 +1,12 @@ +The user just asked the robot: "{episode_task}". + +Generate a short verbal acknowledgement the robot would speak back before +beginning the task. Style: compact, confident, friendly. + +Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.", +"OK, starting with the sponge.", "Got it.". + +Prefer very short replies: "Got it.", "On it.", "OK." + +Output strictly valid JSON: + {{ "text": "" }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/interjections_interjection.txt b/src/lerobot/annotations/steerable_pipeline/prompts/interjections_interjection.txt new file mode 100644 index 000000000..4a4719f54 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/interjections_interjection.txt @@ -0,0 +1,46 @@ +You are generating training data for a Hi Robot-style hierarchical +robot policy. The robot in this demonstration has ALREADY executed +every step shown in the video — we cannot retroactively change the +action stream. To keep training data consistent with the video, the +"interjection" must align with what the robot is *about to do next* in +the demonstration, framed as a natural mid-task user request. + +The episode's overall task: "{episode_task}". + +The images above show roughly {window_seconds:.1f} seconds straddling a +subtask boundary in the demonstration: + +- Subtask the robot just finished: "{prev_subtask}" +- Subtask the robot is about to start: "{next_subtask}" +- Time into episode: {timestamp:.2f}s + +Write ONE compact interjection the user would naturally say at this +moment to prompt / confirm / encourage the robot to do "{next_subtask}". +Keep it like a mid-task coaching cue, not a full instruction paragraph. +Also write the robot's compact verbal acknowledgement. + +Hard rules: + +- The interjection MUST be consistent with the next subtask. The user + cannot ask for something different from what the robot then does in + the video. If you're tempted to say "actually skip X" or "do Y + instead", DO NOT — those would contradict the demonstration. +- The interjection must reference an object, location, or action that + is plausible given the visible scene and the next subtask text. +- One short phrase or sentence each. Conversational, not robotic. +- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}." +- Keep robot speech very short: "OK.", "On it.", "Doing that." + +Style examples (vary the phrasing — don't reuse these verbatim): + - "Now go ahead and {next_subtask}." + - "Great, can you {next_subtask} next?" + - "{next_subtask}, please." + - "Before you continue, please {next_subtask}." + - "Looking good — {next_subtask} now." + - "Okay, {next_subtask}." + +Output strictly valid JSON: + {{ + "interjection": "", + "speech": "" + }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/plan_memory.txt b/src/lerobot/annotations/steerable_pipeline/prompts/plan_memory.txt new file mode 100644 index 000000000..b5278368b --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/plan_memory.txt @@ -0,0 +1,36 @@ +You are updating the robot's compressed semantic memory at the boundary of +a completed subtask. + +Reference (verbatim from MEM, Torne 2026): +"Remove or compress information in the language memory whenever +appropriate. Keep ONLY the minimal set of relevant information for future +task execution. Specific object attributes (colors, precise quantities of +each item) get discarded when their details won't affect subsequent +actions. Functional outcomes (where items went, how many) are preserved." + +Episode task: "{episode_task}" +Previous memory: {prior_memory} +Just-completed subtask: "{completed_subtask}" +Remaining subtasks (for relevance judgement only): {remaining_subtasks} + +Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the +robot has accomplished so far — the running story it would tell itself. + +Authoring rules: +- First person, past tense. Every sentence starts with "I": "I picked + up...", "I opened...", "I moved to...". +- One or two short sentences. Extend the previous memory with the + just-completed subtask; do not rewrite it from scratch. +- Keep WHAT happened (functional outcomes — where items went, how many), + drop HOW (grasp details, motions). +- Compress completed steps and drop object attributes (colors, exact + counts) once they no longer affect the remaining subtasks. + +Example (MEM, Torne 2026): + Before: "I prepared the pot and got the potatoes, milk, and butter. I + moved to the drawer." + After: "I prepared the pot and got the ingredients. I opened the + drawer with the masher." + +Output strictly valid JSON: + {{ "memory": "" }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtask_describe.txt b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtask_describe.txt new file mode 100644 index 000000000..6b709e41d --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtask_describe.txt @@ -0,0 +1,27 @@ +You are watching a teleoperated robot demonstration from a single +camera. The user asked the robot to: "{episode_task}" + +This is an OBSERVATION pass. Watch the entire clip and describe, in +chronological order, ONLY what the robot physically does — the concrete +motions, approaches, contacts, grasps, releases, and relocations you can +actually SEE in the frames. + +Hard rules: +- Describe only motion visible in the video. Do NOT use the task + instruction to guess steps that aren't shown. The instruction is the + goal; the video is ground truth. +- Do NOT segment into named subtasks yet and do NOT output JSON beyond + the single field below. Just narrate what happens. +- Give an approximate timestamp (in seconds) for each distinct event, + e.g. "0.0-1.4s: the base drives forward toward the stove". +- Do NOT invent objects, grasps, destinations, or steps. If the robot + only does one thing (e.g. it just navigates and the clip ends), say + exactly that and nothing more. +- Be concrete and literal. "the gripper closes on the mug" — not "the + robot prepares to make coffee". + +Output strictly valid JSON: + + {{ + "description": "" + }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtasks.txt b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtasks.txt new file mode 100644 index 000000000..e6a5260a7 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtasks.txt @@ -0,0 +1,112 @@ +You are labeling a teleoperated robot demonstration. + +The user originally asked: "{episode_task}" + +You are shown the entire demonstration as a single video. Watch the +whole clip, then segment it into a list of consecutive atomic subtasks +the robot performs. + +{observation_block}GROUNDING — read this first, it overrides everything below: +- Label ONLY what the robot actually does in the video. Every subtask + you emit must correspond to motion you can SEE in specific frames. +- Do NOT invent, anticipate, or pad. If the robot only does one thing + (e.g. it just navigates to a location and the clip ends), emit + EXACTLY ONE subtask. Many demonstrations are a single atomic skill. +- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer + subtasks than the ceiling is not just allowed, it is expected for + short / atomic demonstrations. One correct subtask is far better + than several invented ones. +- If the video does not clearly show the action implied by the task, + describe what you actually see — do NOT fabricate the task's steps + from the instruction text. The instruction tells you the goal; the + VIDEO is the ground truth for what happened. + +Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts: + +- Each subtask = one COMPOSITE atomic skill the low-level policy can + execute end-to-end. A "skill" bundles its own approach motion with + its terminal action — do NOT split the approach off as its own + subtask. The whole-arm policy already learns to reach as part of + every manipulation primitive. +- Write each subtask as an IMPERATIVE COMMAND, starting with one of + these verbs (extend only when none fits): + pick up — approach + grasp + lift in one subtask + put on/in — transport + release in one subtask + place on/in — synonym of "put"; pick one and stay consistent + push — contact + linear shove + pull — contact + linear retract + turn — rotary actuation + press