mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5753f8c18b | |||
| 97bd373d15 | |||
| 10a73e3c95 | |||
| 27c9288b24 | |||
| 378897800a | |||
| fcb371eddd | |||
| 895eaf0d7c | |||
| edda8552ec | |||
| c8225d749a | |||
| 68f869b7a0 | |||
| 4119ad4d10 | |||
| 750358895b | |||
| bc4d0db8f4 | |||
| 45e273b806 | |||
| 8b5f56b63c | |||
| 9f1ee224cb | |||
| 885f55ef04 | |||
| bba996ef8d | |||
| 162b07512a | |||
| 0509ea05df | |||
| de1a9e5ad9 | |||
| 6803439f22 | |||
| 90d1e70da2 | |||
| a35ac22afd | |||
| fd7fed08e2 | |||
| 0c3cc4c9d6 | |||
| 6caeac9d07 | |||
| 1d6810b814 | |||
| de9af57475 | |||
| 364750ada2 | |||
| 342d223706 | |||
| e3b203e5a7 | |||
| b568c41355 |
@@ -178,9 +178,3 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -105,7 +105,7 @@ lerobot-train \
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.7](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||
|
||||
|
||||
@@ -45,8 +45,6 @@
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: video_encoding_parameters
|
||||
title: Video encoding parameters
|
||||
- local: streaming_video_encoding
|
||||
@@ -70,7 +68,7 @@
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
title: NVIDIA GR00T
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: multi_task_dit
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` watches each episode's video with a vision-language
|
||||
model (VLM) and writes natural-language annotations back into your
|
||||
dataset. It fills the two language columns from the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — straight into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
In short: point it at a LeRobot dataset, and it adds subtasks, plans,
|
||||
memory, interjections, speech, and visual Q&A that a policy can be
|
||||
trained on.
|
||||
|
||||
## How it fits together
|
||||
|
||||
```text
|
||||
your dataset lerobot-annotate
|
||||
(LeRobot v3.1)
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ read episodes │
|
||||
└──────────────────────────┬──────────────────────────┘
|
||||
│
|
||||
┌────────────────────┼────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL
|
||||
│ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI
|
||||
└────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three
|
||||
└────────────────────┼─────────────────────┘
|
||||
│ each module stages raw JSONL
|
||||
▼ into .annotate_staging/
|
||||
┌─────────────────┐
|
||||
│ validator │ ◀── checks everything
|
||||
└────────┬────────┘
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ writer │
|
||||
└────────┬────────┘
|
||||
▼
|
||||
data/chunk-*/file-*.parquet
|
||||
(+ meta/info.json tools)
|
||||
```
|
||||
|
||||
Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared
|
||||
VLM. Each module stages its output to disk, a validator checks it, and a
|
||||
single writer rewrites the dataset shards in place.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
Each module emits a few kinds of annotation ("styles"), routed to one of
|
||||
the two language columns:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | --------------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | `plan` |
|
||||
| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` |
|
||||
| `interjection` | `language_events` | `interjections` |
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` |
|
||||
| `vqa` (user / assistant pair) | `language_events` | `vqa` |
|
||||
|
||||
### How subtasks are generated
|
||||
|
||||
The `plan` module doesn't ask the VLM for subtasks in one shot. Instead
|
||||
it uses a two-step **describe → segment** flow:
|
||||
|
||||
1. **Describe** — the VLM narrates only what it actually sees in the
|
||||
chosen camera (no guessing about the task).
|
||||
2. **Segment** — that description is fed back in, and the VLM splits the
|
||||
episode into consecutive atomic subtasks.
|
||||
|
||||
The resulting spans are then stitched into a gap-free, full-episode
|
||||
cover, so **every frame has exactly one active subtask**. See
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
for the production settings (single camera, embedded frames, windowed
|
||||
subtask generation).
|
||||
|
||||
### Tools
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet. The tool
|
||||
catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)).
|
||||
After every run, the pipeline makes sure the canonical `say` schema is in
|
||||
that list, keeping any tools you declared beforehand.
|
||||
|
||||
Want to add your own tool? Edit `meta/info.json["tools"]` directly — the
|
||||
pipeline preserves whatever is already there. That makes the tool visible
|
||||
to the chat template, so the model can learn to _generate_ the call. The
|
||||
runtime layer that actually _executes_ a generated call (the `Tool`
|
||||
protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of
|
||||
this PR — the [Tools](./tools) doc marks those pieces as
|
||||
not-yet-implemented.
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs).
|
||||
The repo ships a launcher script you copy and tweak for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
```
|
||||
|
||||
[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py)
|
||||
starts a single-GPU `h200` job (bump it to `h200x4` for big datasets)
|
||||
that:
|
||||
|
||||
1. installs `lerobot` (from `main`) plus the annotation extras,
|
||||
2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and
|
||||
drives it over the OpenAI-compatible API,
|
||||
3. runs the `plan` / `interjections` / `vqa` modules across the dataset
|
||||
with `lerobot-annotate`,
|
||||
4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or
|
||||
back to `--repo_id` in place if you leave that unset).
|
||||
|
||||
To use a different dataset, model, or hub repo, edit the `CMD` block in
|
||||
the script. Every flag there maps directly to a `lerobot-annotate` flag
|
||||
(run `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Key options
|
||||
|
||||
These are the flags you'll reach for most often. Run
|
||||
`lerobot-annotate --help` for everything else; the defaults are tuned for
|
||||
short manipulation episodes.
|
||||
|
||||
### Dataset in / out
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------- | ------- | ----------------------------------------------------------------------- |
|
||||
| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). |
|
||||
| `--root` | — | Annotate a local dataset directory instead. |
|
||||
| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). |
|
||||
| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). |
|
||||
| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). |
|
||||
| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. |
|
||||
|
||||
### Which modules run
|
||||
|
||||
Every module is on by default and can be toggled independently (set to
|
||||
`false` to skip it, e.g. to iterate on one module at a time):
|
||||
|
||||
| Flag | Default | Turns off |
|
||||
| ------------------------- | ------- | ----------------------------------- |
|
||||
| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug |
|
||||
| `--interjections.enabled` | `true` | interjections + speech atoms |
|
||||
| `--vqa.enabled` | `true` | the VQA pairs |
|
||||
|
||||
### The VLM (`--vlm.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- |
|
||||
| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. |
|
||||
| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. |
|
||||
| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). |
|
||||
| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). |
|
||||
| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). |
|
||||
| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. |
|
||||
| `--vlm.max_new_tokens` | `512` | Generation cap per call. |
|
||||
| `--vlm.temperature` | `0.2` | Sampling temperature. |
|
||||
|
||||
### Subtasks / plan / memory (`--plan.*`)
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `--plan.frames_per_second` | `1.0` | How densely the episode video is sampled. |
|
||||
| `--plan.max_video_frames` | `32` | Hard cap on frames per call (context-budget guard — don't exceed ~32 for a 32k context). |
|
||||
| `--plan.subtask_window_seconds` | `0` | Split long episodes into fixed windows for constant frame density (`0` = whole episode). |
|
||||
| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. |
|
||||
| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). |
|
||||
| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). |
|
||||
| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). |
|
||||
| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). |
|
||||
| `--plan.use_video_url` | `false` | Send a server-side video clip instead of embedded frames. |
|
||||
|
||||
### Interjections + VQA
|
||||
|
||||
| Flag | Default | What it does |
|
||||
| ----------------------------------------------- | ------- | ---------------------------------------------------------- |
|
||||
| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. |
|
||||
| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. |
|
||||
| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). |
|
||||
| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. |
|
||||
|
||||
## Contributing new modules
|
||||
|
||||
The pipeline is built to grow, and **contributions are very welcome** —
|
||||
a brand-new module (say, trajectory traces or affordances), a new prompt
|
||||
template, a smarter grounding flow, or quality fixes to the existing
|
||||
`plan` / `interjections` / `vqa` modules.
|
||||
|
||||
Every module lives under
|
||||
`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM
|
||||
client and the keyframe cache, writes its raw output to the staging
|
||||
tree, and plugs into the executor as its own phase. Got an idea? Open an
|
||||
issue or PR on [the repo](https://github.com/huggingface/lerobot).
|
||||
|
||||
## How recipes consume the output
|
||||
|
||||
The annotations are meant to be read by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)). Typically:
|
||||
|
||||
- low-level / high-level / memory-update branches read
|
||||
`subtask` / `plan` / `memory` from `language_persistent`.
|
||||
- an interjection-response branch reads `interjection` events plus the
|
||||
paired speech atom (merged into one assistant turn via `tool_calls_from`)
|
||||
and the matching `plan` refresh at the same timestamp.
|
||||
- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from
|
||||
`language_events`.
|
||||
|
||||
## Why state and events are split
|
||||
|
||||
Two ideas shape the design:
|
||||
|
||||
1. **Persistent state vs. exact events.** Persistent rows (`subtask`,
|
||||
`plan`, `memory`) apply to the whole episode and answer "what's true
|
||||
right now?". Event rows (`interjection`, `vqa`, speech) appear only on
|
||||
the one frame whose timestamp matches. Timestamps are copied straight
|
||||
from the source parquet — never recomputed in floating point.
|
||||
2. **One VLM pass.** All three modules share a single VLM client (the
|
||||
OpenAI-compatible client talking to the job's vLLM server), so you pay
|
||||
for one model load per dataset, not three.
|
||||
|
||||
## Re-running a single module
|
||||
|
||||
Each module stages its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes
|
||||
prompt iteration cheap: re-running one module overwrites only its own
|
||||
JSONL, then the writer recomposes the final parquet. Disable modules you
|
||||
don't want with `--plan.enabled=false` (and likewise
|
||||
`--interjections.enabled` / `--vqa.enabled`) to test one at a time.
|
||||
|
||||
## What the validator checks
|
||||
|
||||
Before the writer runs, `StagingValidator` confirms:
|
||||
|
||||
- every event row lands exactly on a real frame timestamp;
|
||||
- no speech / interjection pairs are left orphaned;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (a warning, not an error);
|
||||
- each VQA assistant `content` is valid JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row goes to the column chosen by `column_for_style(style)`.
|
||||
|
||||
Any error aborts the writer. Pass `--skip_validation=true` to override
|
||||
while debugging.
|
||||
|
||||
## Where each module's ideas come from
|
||||
|
||||
- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
for atom granularity ("pick up one piece of lettuce", "place bowl to
|
||||
box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07))
|
||||
for "how, not what" detail.
|
||||
- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)):
|
||||
keep only the minimal relevant information — preserve outcomes, drop
|
||||
specific attributes.
|
||||
- **`interjections`.** Hi Robot's scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom
|
||||
(`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`).
|
||||
- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for
|
||||
grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies
|
||||
([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction
|
||||
grounding. Pi0.7 also grounds answers across abstraction levels.
|
||||
|
||||
When improving a module, tweak its prompt template in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` rather than
|
||||
rewriting from scratch.
|
||||
|
||||
## Roughly how much it costs
|
||||
|
||||
Per episode, the pipeline makes about `max_steps` plan calls,
|
||||
`max_interjections_per_episode` interjection calls, and
|
||||
`vqa_emission_hz × episode_seconds` VQA calls. With the defaults (8
|
||||
subtasks, 1 interjection, 1 Hz × 3 pairs) on a 30-second episode, that's
|
||||
~50 VLM calls.
|
||||
|
||||
Storage stays small: `language_persistent` is at most tens of KB per
|
||||
episode (parquet dictionary-encodes the one entry that repeats across
|
||||
frames), and `language_events` is empty on most frames — its size scales
|
||||
with the number of emissions, not `num_frames × num_emissions`.
|
||||
@@ -193,7 +193,7 @@ To learn more about training policies with LeRobot, please refer to the training
|
||||
|
||||
- [SmolVLA](./smolvla)
|
||||
- [Pi0.5](./pi05)
|
||||
- [GR00T N1.5](./groot)
|
||||
- [GR00T N1.7](./groot)
|
||||
|
||||
Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation:
|
||||
|
||||
|
||||
+79
-30
@@ -1,16 +1,19 @@
|
||||
# GR00T N1.5 Policy
|
||||
# GR00T Policy
|
||||
|
||||
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
|
||||
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
|
||||
|
||||
This document outlines the specifics of its integration and usage within the LeRobot framework.
|
||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||
|
||||
> [!WARNING]
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
|
||||
## Model Overview
|
||||
|
||||
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
|
||||
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
|
||||
|
||||
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
|
||||
|
||||
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
|
||||
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
|
||||
@@ -28,33 +31,46 @@ This approach allows the model to be highly adaptable through post-training for
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
As of today, GR00T N1.5 requires flash attention for it's internal working.
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
|
||||
|
||||
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
|
||||
|
||||
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
|
||||
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
|
||||
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
|
||||
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
|
||||
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
|
||||
```
|
||||
|
||||
3. Install and verify Flash Attention:
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for your system
|
||||
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
|
||||
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
3. Install LeRobot by running:
|
||||
4. Install LeRobot with the GR00T extra:
|
||||
|
||||
```bash
|
||||
pip install lerobot[groot]
|
||||
pip install "lerobot[groot]"
|
||||
```
|
||||
|
||||
For a source checkout, use the same order, then install the local package with:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T in your LeRobot configuration, specify the policy type as:
|
||||
To use GR00T N1.7:
|
||||
|
||||
```python
|
||||
policy.type=groot
|
||||
```bash
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7
|
||||
```
|
||||
|
||||
## Training
|
||||
@@ -87,21 +103,54 @@ accelerate launch \
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
### LIBERO Benchmark Results
|
||||
|
||||
> [!NOTE]
|
||||
> Follow our instructions for Libero usage: [Libero](./libero)
|
||||
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
|
||||
|
||||
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
|
||||
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
|
||||
|
||||
| Benchmark | LeRobot Implementation | GR00T Reference |
|
||||
| ------------------ | ---------------------- | --------------- |
|
||||
| **Libero Spatial** | 82.0% | 92.0% |
|
||||
| **Libero Object** | 99.0% | 92.0% |
|
||||
| **Libero Long** | 82.0% | 76.0% |
|
||||
| **Average** | 87.0% | 87.0% |
|
||||
### GR00T N1.7 LIBERO Checkpoints
|
||||
|
||||
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
|
||||
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
|
||||
|
||||
| Suite | Checkpoint subdirectory |
|
||||
| -------------- | ----------------------- |
|
||||
| LIBERO Spatial | `libero_spatial` |
|
||||
| LIBERO Object | `libero_object` |
|
||||
| LIBERO Goal | `libero_goal` |
|
||||
| LIBERO 10 | `libero_10` |
|
||||
|
||||
Preliminary LeRobot integration results:
|
||||
|
||||
| Suite | Status | Success rate | n_episodes |
|
||||
| -------------- | ------ | -----------: | ---------: |
|
||||
| LIBERO Spatial | ✓ | ~95% | XX |
|
||||
| LIBERO Object | ✓ | XX% | XX |
|
||||
| LIBERO Goal | ✓ | XX% | XX |
|
||||
| LIBERO 10 | ✓ | XX% | XX |
|
||||
| **Average** | ✓ | **XX%** | **XX** |
|
||||
|
||||
Replace the `XX` placeholders with final eval artifacts before merge.
|
||||
|
||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||
|
||||
```bash
|
||||
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
--include "libero_spatial/*" \
|
||||
--local-dir ./GR00T-N1.7-LIBERO
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7 \
|
||||
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--eval.n_episodes=50
|
||||
```
|
||||
|
||||
Use `eval.n_episodes >= 50` per suite when reporting success rates.
|
||||
|
||||
### Evaluate in your hardware setup
|
||||
|
||||
@@ -131,4 +180,4 @@ lerobot-rollout\
|
||||
|
||||
## License
|
||||
|
||||
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
|
||||
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
|
||||
|
||||
@@ -647,6 +647,5 @@ The `--strategy.type` flag selects the execution mode:
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
- `episodic`: Episode-oriented policy recording with reset phases between episodes
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -157,44 +157,6 @@ Foot pedal input is also supported via `--strategy.input_device=pedal`. Configur
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
### Episodic (`--strategy.type=episodic`)
|
||||
|
||||
Episode-oriented recording that mirrors the behavior of `lerobot-record`. The policy drives the robot for each episode; an optional teleoperator can drive the robot during the reset phase between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=episodic \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/my_eval_data \
|
||||
--dataset.num_episodes=20 \
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10 \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
Teleop is optional — if omitted the robot holds its position during the reset phase.
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ----------- | -------------------------------- |
|
||||
| `→` (right) | End the current episode early |
|
||||
| `←` (left) | Discard episode and re-record it |
|
||||
| `ESC` | Stop the recording session |
|
||||
|
||||
| Flag | Description |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------- |
|
||||
| `--dataset.num_episodes` | Number of episodes to record |
|
||||
| `--dataset.episode_time_s` | Duration of each recording episode in seconds |
|
||||
| `--dataset.reset_time_s` | Duration of the reset phase between episodes in seconds |
|
||||
| `--teleop.type` | Optional. Teleoperator to drive the robot during resets |
|
||||
| `--strategy.reset_to_initial_position` | Whether to reset the robot to its initial position between episodes |
|
||||
| `--strategy.smooth_leader_to_follower_handover` | Whether to turn on or off the leader -> follower smooth handover behavior. |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
@@ -141,11 +141,6 @@ sample["target_message_indices"]
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Blends
|
||||
|
||||
Blend recipes select one weighted sub-recipe deterministically from the sample index.
|
||||
`recipes/subtasks_vqa.yaml` trains the core blend — high-level subtask prediction, low-level execution, and VQA. `recipes/subtask_mem_vqa_speech.yaml` is the fuller variant that also adds memory updates and spoken interjection responses.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
## Research Paper
|
||||
|
||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
|
||||
|
||||
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
|
||||
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||
|
||||
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
|
||||
> Current releases support GR00T N1.7 only.
|
||||
|
||||
## Repository
|
||||
|
||||
@@ -24,4 +31,103 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
|
||||
|
||||
Blog: https://developer.nvidia.com/isaac/gr00t
|
||||
|
||||
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
|
||||
Hugging Face Models:
|
||||
|
||||
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO
|
||||
|
||||
## Original-vs-LeRobot parity test
|
||||
|
||||
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
|
||||
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
||||
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
|
||||
over every embodiment tag present in the checkpoint:
|
||||
|
||||
1. **Model parity** — given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed (recorded in each artifact), both implementations must produce
|
||||
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
|
||||
flow-matching prediction). Output shapes must match exactly; any action-horizon
|
||||
or action-dim mismatch fails the test.
|
||||
2. **Preprocessor parity** — given the identical raw observations (per-camera
|
||||
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
|
||||
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
|
||||
state normalization, no mocks) must produce the **same collated model inputs**
|
||||
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
|
||||
`embodiment_id`) as the original package's processor.
|
||||
|
||||
### Why two environments
|
||||
|
||||
The original `gr00t` package pins `transformers==4.57.3` (Python 3.10); this
|
||||
integration requires `transformers>=5.x` (Qwen3-VL). Under 5.x, `PretrainedConfig`
|
||||
is itself a defaulted dataclass, so the original config dataclasses fail to import
|
||||
(`non-default argument follows default argument`). The two implementations therefore
|
||||
**cannot be imported in the same Python process**.
|
||||
|
||||
So the test uses a **producer / consumer** split across two venvs:
|
||||
|
||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
|
||||
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
||||
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
||||
the processor modality configs), runs the original model, and saves to one `.npz`
|
||||
per tag: the raw observations (`raw::` keys), the exact collated inputs
|
||||
(`in::` keys), the seed, and the raw `action_pred`.
|
||||
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
|
||||
`.npz`; the model-parity case replays the byte-identical collated inputs through
|
||||
the LeRobot model with the recorded seed and asserts the outputs match, and the
|
||||
preprocessor-parity case replays the raw observations through LeRobot's full
|
||||
preprocessor pipeline and asserts the collated tensors match.
|
||||
|
||||
> Artifacts generated by older versions of the dump script contain no `raw::`
|
||||
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
|
||||
> Re-run the producer to refresh them.
|
||||
|
||||
### Fairness controls
|
||||
|
||||
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
|
||||
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
|
||||
model comparison isolates the model. LeRobot's own tokenization / image packing is
|
||||
covered separately by the preprocessor-parity case, which compares its output
|
||||
against those same collated tensors from identical raw observations.
|
||||
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
||||
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
||||
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
||||
kernel/rounding noise, not an implementation difference.)
|
||||
- **Same flow-matching seed** — fixed right before sampling on both sides; the
|
||||
producer records it in each artifact (`--seed`, default 42) and the consumer
|
||||
replays the recorded value.
|
||||
|
||||
### How to run
|
||||
|
||||
```bash
|
||||
# Resolve a local checkpoint (GR00T-N1.7-LIBERO / libero_10)
|
||||
CKPT=$(python - <<'PY'
|
||||
import os
|
||||
from huggingface_hub import snapshot_download
|
||||
print(os.path.join(snapshot_download("nvidia/GR00T-N1.7-LIBERO",
|
||||
allow_patterns=["libero_10/*"]), "libero_10"))
|
||||
PY
|
||||
)
|
||||
|
||||
# 1) Produce the original-side artifacts for all embodiments (original gr00t venv, CUDA)
|
||||
CUDA_VISIBLE_DEVICES=0 /path/to/Isaac-GR00T/.venv-original/bin/python \
|
||||
tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
--ckpt "$CKPT" --out-dir tests/policies/groot/artifacts --device cuda --seed 42
|
||||
|
||||
# 2) Run the parity test (LeRobot venv) — one parametrized case per embodiment
|
||||
CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
|
||||
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
||||
```
|
||||
|
||||
The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
|
||||
the producer; they are never committed. The tests **skip** (do not fail) on CI or
|
||||
when the checkpoint / artifacts are absent.
|
||||
|
||||
#### Env knobs (all optional)
|
||||
|
||||
| Var | Default | Purpose |
|
||||
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
|
||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM).
|
||||
|
||||
Spawns one single-GPU ``h200`` job that:
|
||||
|
||||
1. installs ``lerobot`` from ``main`` plus the annotation extras,
|
||||
2. boots one vllm server with Qwen3.6-27B (dense VLM),
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (each episode generates its own subtasks +
|
||||
memory),
|
||||
4. uploads the annotated dataset to ``--new_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your
|
||||
run. For larger datasets, scale to ``h200x4`` and raise
|
||||
``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect "
|
||||
"openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=pepijn223/robocasa_pretrain_human300_v4 "
|
||||
"--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated5 "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-27B "
|
||||
"--vlm.parallel_servers=1 "
|
||||
"--vlm.num_gpus=1 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=128 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
"--vlm.temperature=0.7 "
|
||||
"--executor.episode_parallelism=16 "
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||
"--vlm.camera_key=observation.images.robot0_agentview_right "
|
||||
# Phase 1 — plan module (subtasks + memory).
|
||||
# Embed decoded frames (not a file:// clip): if clip extraction fails,
|
||||
# the video_url path silently sends no video and the VLM hallucinates.
|
||||
"--plan.use_video_url=false "
|
||||
"--plan.frames_per_second=1.0 "
|
||||
# 32 frames ≈ 8-10k vision tokens, fits the 32768 context. Don't push
|
||||
# toward 128 — that overflows the context (BadRequestError 400).
|
||||
"--plan.max_video_frames=32 "
|
||||
# Window long episodes into 32s chunks (constant 1 fps density) so they
|
||||
# get more subtasks; per-window spans are merged + stitched. 0 disables.
|
||||
"--plan.subtask_window_seconds=32 "
|
||||
# RoboCasa: the dataset task string is authoritative (eval uses it), so
|
||||
# keep it driving subtasks. ``always`` would throw it away and hallucinate.
|
||||
"--plan.derive_task_from_video=off "
|
||||
# No task augmentation: eval conditions on the exact task strings, so
|
||||
# rephrasings are unused at best and harmful when they drift.
|
||||
"--plan.n_task_rephrasings=0 "
|
||||
# Keep subtask decomposition tight for atomic tasks.
|
||||
"--plan.plan_max_steps=10 "
|
||||
# Only subtasks + memory — skip the numbered "plan" rows. true re-enables.
|
||||
"--plan.emit_plan=false "
|
||||
# The describe->segment grounding pass (+1 VLM call/episode) is ON by
|
||||
# default; pass --plan.subtask_describe_first=false to skip it.
|
||||
# Phase 2 — interjections + speech.
|
||||
"--interjections.max_interjections_per_episode=6 "
|
||||
# Phase 4 — general VQA: disabled for this run.
|
||||
"--vqa.enabled=false"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
+2
-34
@@ -85,11 +85,6 @@ dependencies = [
|
||||
"termcolor>=2.4.0,<4.0.0",
|
||||
"tqdm>=4.66.0,<5.0.0",
|
||||
|
||||
# Training utilities
|
||||
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
|
||||
# pure-python dependency — preferred over a hand-rolled implementation.
|
||||
"ema-pytorch>=0.7.7,<1.0.0",
|
||||
|
||||
# Build tools (required by opencv-python-headless on some platforms)
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
@@ -147,7 +142,6 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
sentencepiece-dep = ["sentencepiece>=0.2.0,<0.3.0"] # FAST action tokenizer backend (pi052, pi0_fast)
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -203,7 +197,7 @@ wallx = [
|
||||
"torchdiffeq>=0.2.4,<0.3.0",
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]", "lerobot[sentencepiece-dep]"]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
@@ -229,29 +223,6 @@ vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). The only backend is ``openai``,
|
||||
# which talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs
|
||||
# (see examples/annotations/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
# ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and
|
||||
# uv's single unified lock would then cap ``torch`` for every extra
|
||||
# (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI
|
||||
# break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM;
|
||||
# install it locally only if you run your own ``vllm serve``.
|
||||
]
|
||||
|
||||
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
|
||||
# are isolated so adding a new tool doesn't bloat the base install.
|
||||
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
|
||||
tools = [
|
||||
"pocket-tts>=1.0.0,<3.0.0",
|
||||
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
@@ -346,10 +317,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
|
||||
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
@@ -367,7 +335,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
]
|
||||
@@ -1,196 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanConfig:
|
||||
"""``plan`` module: subtasks + plan + memory + task augmentation."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# Derive the task from video instead of episode_task: off / if_short / always.
|
||||
# Affects prompts only; ``meta/tasks.parquet`` is untouched.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# Frames sampled uniformly, capped at max_video_frames — a hard context cap
|
||||
# (~300 tokens/frame, so 32 fit a 32k VLM; 128 overflow).
|
||||
frames_per_second: float = 1.0
|
||||
max_video_frames: int = 32
|
||||
|
||||
# >0: split long episodes into windows of this length (constant fps density)
|
||||
# instead of subsampling the whole episode; spans merged + stitched. 0 disables.
|
||||
subtask_window_seconds: float = 0.0
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# Narrate-only grounding pass before segmenting — best defense against subtasks
|
||||
# invented from the task text (+1 VLM call/episode).
|
||||
subtask_describe_first: bool = True
|
||||
|
||||
# Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only.
|
||||
emit_plan: bool = True
|
||||
|
||||
# (subtask spans are always stitched to a contiguous full-episode cover; not configurable.)
|
||||
|
||||
# Send a server-side ``video_url`` clip (at use_video_url_fps) instead of embedded frames.
|
||||
use_video_url: bool = False
|
||||
use_video_url_fps: float = 1.0
|
||||
|
||||
# Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings.
|
||||
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskAugAxesConfig:
|
||||
"""5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm /
|
||||
omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings
|
||||
when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to
|
||||
omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic."""
|
||||
|
||||
enabled: bool = False
|
||||
|
||||
synonym_paraphrase: int = 3
|
||||
omit_arm: int = 3
|
||||
omit_orientation: int = 2
|
||||
omit_grasp_method: int = 2
|
||||
combined_omissions: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsConfig:
|
||||
"""``interjections`` module: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each emits a paired (interjection, speech) row + a plan refresh at that ts.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Frame window centered on the timestamp so the VLM sees motion, not one frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class VqaConfig:
|
||||
"""``vqa`` module: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 1
|
||||
"""Consecutive frames per emission tick. The VLM grounds on the FIRST frame,
|
||||
so K>1 smears stale labels onto moved frames. Default 1 (no smear)."""
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
# True: ground VQA only on --vlm.camera_key (default: every camera).
|
||||
restrict_to_default_camera: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when
|
||||
# auto_serve=True); ``stub`` is for tests.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen3.6-27B"
|
||||
|
||||
# OpenAI-compatible endpoint; ``EMPTY`` key works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# Spawn a server if none answers api_base; False = fail fast on a remote.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command; ``{port}`` substituted per replica.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
|
||||
# Auto-serve context length (None → 32768); other vLLM flags go in serve_command.
|
||||
max_model_len: int | None = None
|
||||
|
||||
# Camera for keyframes; None → first ``observation.images.*`` key.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}).
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings (intra-process episode concurrency; distribution via HF Jobs)."""
|
||||
|
||||
# Episodes processed concurrently per phase; main knob for saturating the servers.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate`` (rewrites data shards in place)."""
|
||||
|
||||
# Hub dataset: download source when ``root`` unset; push target when push_to_hub
|
||||
# is on and ``new_repo_id`` unset.
|
||||
repo_id: str | None = None
|
||||
|
||||
# Separate push target (matches the LeRobot edit tools). Unset → push in place.
|
||||
new_repo_id: str | None = None
|
||||
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/``.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
plan: PlanConfig = field(default_factory=PlanConfig)
|
||||
interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig)
|
||||
vqa: VqaConfig = field(default_factory=VqaConfig)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Keyframe decode backend. None → ffmpeg CLI (crash-/thread-safe; torchcodec
|
||||
# SIGSEGVs under concurrent decode). Or ``"torchcodec"`` / ``"pyav"``.
|
||||
video_backend: str | None = None
|
||||
|
||||
# Upload to the Hub (new_repo_id if set, else repo_id; one must be set).
|
||||
push_to_hub: bool = False
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
@@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor runs **six phases** in dependency order:
|
||||
|
||||
phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phase 2: ``interjections`` module (interjections + speech)
|
||||
phase 3: ``plan`` plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: ``vqa`` module (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why the ``plan`` module must be re-entered after the
|
||||
``interjections`` module — to refresh ``plan`` rows at interjection
|
||||
timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotations/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all six phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
plan: Any # PlanSubtasksMemoryModule
|
||||
interjections: Any # InterjectionsAndSpeechModule
|
||||
vqa: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: ``plan`` module (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("plan", records, staging_dir, self.plan))
|
||||
# Phase 2: ``interjections`` module (interjections + speech). It
|
||||
# reads the ``plan`` module's subtask rows from the same staging
|
||||
# tree to ground the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections))
|
||||
# Phase 3: ``plan`` plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: ``vqa`` module (VQA)
|
||||
phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = load_info(root)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
merged_features = {**info.features, **language_feature_info()}
|
||||
if merged_features != info.features:
|
||||
info.features = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.tools or []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
info.tools = [*existing, SAY_TOOL_SCHEMA]
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
write_info(info, root)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in (info.tools or [])]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced.
|
||||
|
||||
The ``plan`` module owns the prompt; the ``interjections`` module
|
||||
produced the timestamps. This phase therefore calls back into the
|
||||
``plan`` module with the interjection timestamps so its existing
|
||||
prompt path is reused.
|
||||
"""
|
||||
if not self.plan.enabled or not self.interjections.enabled:
|
||||
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("interjections") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
@@ -1,498 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one decoded frame per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape
|
||||
:func:`lerobot.datasets.video_utils.decode_video_frames` returns.
|
||||
:func:`to_image_blocks` converts them to PIL only at the VLM-message
|
||||
boundary.
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(the ``plan`` and ``interjections`` modules) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` decoded frames covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. Frames are
|
||||
``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps
|
||||
them into one ``{"type":"video", "video":<list>}`` block for a
|
||||
Qwen-VL-compatible model that pools temporally itself. Empty list if
|
||||
no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for the ``plan`` module
|
||||
(subtask decomposition) and the ``interjections`` module (interjection
|
||||
scenarios) — those prompts care about *what is happening*, not which
|
||||
angle. The ``vqa`` module instead iterates over every camera in
|
||||
:attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped ``interjections`` + ``plan`` plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
# Keyframe decode backend. ``None`` uses the ffmpeg CLI — the
|
||||
# concurrency- and crash-safe default for the pipeline's threaded
|
||||
# decode. Set to ``"torchcodec"`` or ``"pyav"`` to pin an in-process
|
||||
# decoder when the build is known thread-safe.
|
||||
video_backend: str | None = None
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs the three module phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
_warned_decode_fail: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# Only ``video_keys`` are decodable here: the clip/decode paths read
|
||||
# ``videos/<key>/from_timestamp`` from episode metadata, which exists
|
||||
# only for video-stored cameras. Image-stored cameras (also in
|
||||
# ``camera_keys``) would KeyError, so restrict the list — and the
|
||||
# default — to video keys.
|
||||
keys = list(self._meta.video_keys)
|
||||
# Last-resort fallback: if metadata didn't surface any video keys but
|
||||
# the caller explicitly named a camera (``--vlm.camera_key=...``),
|
||||
# trust them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# ``_decode`` returns exactly one frame per requested timestamp,
|
||||
# or an empty list if decoding failed wholesale. A partial list
|
||||
# would mean a frame/timestamp misalignment, so only pair them up
|
||||
# when the counts match (``strict=True`` then guards regressions).
|
||||
if len(decoded) == len(miss_indices):
|
||||
with self._lock:
|
||||
for i, frame in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = frame
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = frame
|
||||
# filter out any None left over from decode failures
|
||||
return [frame for frame in out if frame is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` frames uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally. Frames are
|
||||
``torch.Tensor`` (see :meth:`frames_at`).
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips
|
||||
re-extract when the cached clip already exists. Re-encodes to
|
||||
H.264 (libx264) so the resulting mp4 is decodable by every
|
||||
downstream video processor — stream-copy would inherit the
|
||||
source codec (often AV1 in modern LeRobot datasets), which
|
||||
vllm's libav build cannot decode.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
# ffmpeg is invoked by name via PATH lookup (the standard way to
|
||||
# call the CLI); the arg list is fully controlled here, not shell.
|
||||
subprocess.run(cmd, check=True, timeout=300) # nosec B607
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
"""Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors.
|
||||
|
||||
Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames`
|
||||
(torchcodec by default, PyAV fallback) rather than a bespoke decoder.
|
||||
Returns one frame per requested timestamp, or ``[]`` if decoding
|
||||
failed wholesale — callers treat ``[]`` as "no frames available".
|
||||
"""
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
# Default to the ffmpeg CLI. The pipeline decodes under a 16-wide
|
||||
# ThreadPoolExecutor and the in-process decoders are unsafe there:
|
||||
# torchcodec is not thread-safe and SIGSEGVs under concurrent decode
|
||||
# (a crash no try/except can catch), PyAV can likewise segfault on
|
||||
# AV1, and lerobot's ``pyav`` backend routes through the removed
|
||||
# ``torchvision.io.VideoReader``. ``_decode_frames_ffmpeg`` shells
|
||||
# out per frame: each decode is an isolated child process, so it is
|
||||
# both crash-safe and concurrency-safe. ``video_backend`` can pin
|
||||
# ``torchcodec`` / ``pyav`` explicitly for callers that know their
|
||||
# build is safe.
|
||||
chain = [self.video_backend] if self.video_backend else ["ffmpeg"]
|
||||
|
||||
exc: Exception | None = None
|
||||
for backend in chain:
|
||||
try:
|
||||
if backend == "ffmpeg":
|
||||
return _decode_frames_ffmpeg(video_path, shifted)
|
||||
if backend in ("pyav", "av"):
|
||||
return _decode_frames_av(video_path, shifted)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
|
||||
)
|
||||
return list(decoded)
|
||||
except Exception as e: # noqa: PERF203
|
||||
exc = e
|
||||
|
||||
# Every backend raised. Log loudly the first time so a silent
|
||||
# vqa-module no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = self._warned_decode_fail
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
logger.warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backends=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
chain,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def make_frame_provider(
|
||||
root: Path, camera_key: str | None = None, video_backend: str | None = None
|
||||
) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def _decode_frames_ffmpeg(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` via the ffmpeg CLI.
|
||||
|
||||
Runs one ``ffmpeg`` process per timestamp, seeking with ``-ss`` and
|
||||
piping a single PNG to stdout. Unlike the in-process decoders this
|
||||
survives a hostile container: a full ffmpeg build decodes AV1 (the codec
|
||||
modern LeRobot datasets use) where torchcodec raises and PyAV can
|
||||
SIGSEGV, and a crash stays isolated to the child process — a non-zero
|
||||
exit is a catchable error, not a segfault of the whole job. Returns one
|
||||
``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import io # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
import numpy as np # noqa: PLC0415
|
||||
|
||||
frames: list[Any] = []
|
||||
for ts in timestamps:
|
||||
# ffmpeg invoked by name via PATH lookup; fully-controlled arg list, no shell.
|
||||
proc = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{max(ts, 0.0):.3f}",
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-frames:v",
|
||||
"1",
|
||||
"-f",
|
||||
"image2pipe",
|
||||
"-vcodec",
|
||||
"png",
|
||||
"pipe:1",
|
||||
],
|
||||
capture_output=True,
|
||||
check=True,
|
||||
timeout=120,
|
||||
)
|
||||
if not proc.stdout:
|
||||
raise RuntimeError(f"ffmpeg returned no frame for t={ts:.3f}s of {video_path}")
|
||||
img = PIL.Image.open(io.BytesIO(proc.stdout)).convert("RGB")
|
||||
frames.append(torch.from_numpy(np.asarray(img).copy()).permute(2, 0, 1).contiguous())
|
||||
return frames
|
||||
|
||||
|
||||
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||
|
||||
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||
talks to the ``av`` package directly. Note PyAV can SIGSEGV on AV1
|
||||
streams in some builds — prefer ``_decode_frames_ffmpeg`` as the default
|
||||
fallback; this stays available behind ``video_backend="pyav"``. Returns
|
||||
one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# Seek to the keyframe at or before the first requested timestamp.
|
||||
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||
for idx, frame in enumerate(container.decode(stream)):
|
||||
ts = frame.time
|
||||
if ts is None:
|
||||
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||
loaded_ts.append(ts)
|
||||
loaded_frames.append(
|
||||
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||
)
|
||||
if ts >= last_ts:
|
||||
break
|
||||
if not loaded_frames:
|
||||
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||
ts_tensor = torch.tensor(loaded_ts)
|
||||
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8,
|
||||
straight from :func:`decode_video_frames`); PIL is only created here, at
|
||||
the VLM-message boundary, because the chat backends expect PIL images /
|
||||
data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched.
|
||||
"""
|
||||
if not isinstance(frame, torch.Tensor):
|
||||
return frame
|
||||
array = frame.detach().cpu()
|
||||
if array.ndim == 3 and array.shape[0] in (1, 3):
|
||||
array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C)
|
||||
if array.shape[-1] == 1:
|
||||
array = array.squeeze(-1)
|
||||
return PIL.Image.fromarray(array.to(torch.uint8).numpy())
|
||||
|
||||
|
||||
def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert decoded frames to Qwen-VL-compatible image content blocks."""
|
||||
return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames]
|
||||
|
||||
|
||||
def to_video_block(frames: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of decoded frames as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not frames:
|
||||
return []
|
||||
return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
@@ -1,25 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -1,248 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``vqa`` module: general VQA at a timed cadence.
|
||||
|
||||
Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K``
|
||||
consecutive frames, and every anchored frame gets its own VQA pair. Each
|
||||
pair is grounded on that single anchor frame — there is no per-pair frame
|
||||
window. For datasets with multiple cameras, every anchored frame produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's ``vqa`` table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import VqaConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: VqaConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
_warned_no_camera: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not self._warned_no_camera:
|
||||
logging.getLogger(__name__).warning(
|
||||
"vqa module found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("vqa", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("vqa", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras the ``vqa`` module should iterate per anchored frame.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
|
||||
When ``config.restrict_to_default_camera`` is set, VQA grounds on
|
||||
only the provider's default camera (the single ``--vlm.camera_key``
|
||||
stream), matching the plan / interjection modules so the whole
|
||||
pipeline focuses on one view.
|
||||
"""
|
||||
all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
if getattr(self.config, "restrict_to_default_camera", False):
|
||||
default = getattr(self.frame_provider, "camera_key", None)
|
||||
if default and default in all_cameras:
|
||||
return [default]
|
||||
# ``restrict_to_default_camera`` is set but the configured default
|
||||
# isn't one the provider exposes. Returning it anyway would make
|
||||
# ``_decode`` raise a KeyError deep in frame extraction, so warn and
|
||||
# fall through to every available camera instead.
|
||||
if default:
|
||||
logging.getLogger(__name__).warning(
|
||||
"restrict_to_default_camera is set but camera_key=%r is not in the "
|
||||
"provider's cameras %s; grounding VQA on all available cameras instead.",
|
||||
default,
|
||||
all_cameras,
|
||||
)
|
||||
return all_cameras
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
||||
interjection timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import InterjectionsConfig
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: InterjectionsConfig
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull the ``plan`` module's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. The ``plan`` module ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("interjections", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("interjections_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("interjections_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
first_ts = float(frame_timestamps[0])
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(first_ts, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -1,712 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ..config import PlanConfig
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
the ``interjections`` module completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: PlanConfig
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Task driving every plan-module prompt: canonical episode_task, or a
|
||||
# video-derived one when it's empty/placeholder (see derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# task_aug rows at t=0: phrasings the renderer rotates ${task} through.
|
||||
# Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
|
||||
# free-form n_task_rephrasings; the effective task is always emitted
|
||||
# first so the rotation covers the source-of-truth phrasing.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
variants: list[str] | None = None
|
||||
if self.config.task_aug_axes.enabled and effective_task:
|
||||
variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
|
||||
elif self.config.n_task_rephrasings > 0 and effective_task:
|
||||
variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
if variants is not None:
|
||||
rows.extend(self._task_aug_rows([effective_task, *variants], t0))
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# Plan rows at every subtask boundary (incl. t=0). The plan is a
|
||||
# numbered list of still-todo subtasks, so re-emitting at each
|
||||
# boundary makes it shrink as work progresses — ${plan} at frame t is
|
||||
# exactly what's left to do.
|
||||
if self.config.emit_plan:
|
||||
for span in subtask_spans:
|
||||
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
plan_text = self._generate_plan(
|
||||
record, subtask_spans, refresh_t=boundary_t, task=effective_task
|
||||
)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(boundary_t),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start
|
||||
prior_memory = ""
|
||||
for i, span in enumerate(subtask_spans[1:], start=1):
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("plan", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives the ``plan`` module for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
@staticmethod
|
||||
def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
|
||||
"""Build deduplicated ``task_aug`` rows (role=user) at ``t0``."""
|
||||
seen: set[str] = set()
|
||||
rows: list[dict[str, Any]] = []
|
||||
for phrasing in phrasings:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
|
||||
)
|
||||
return rows
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers — every plan-module prompt follows the same shape:
|
||||
# build messages → single VLM call → pull a named field.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prompt: str,
|
||||
window: tuple[float, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""User message combining the (optionally windowed) video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Structured 5-axis task augmentation (EgoMimic-style taxonomy)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
|
||||
"""One VLM call → variants along the 5-axis taxonomy.
|
||||
|
||||
Variants from all axes are flattened into a single list (the
|
||||
downstream pipeline doesn't need to know about the per-axis
|
||||
bucketing — every variant becomes a ``task_aug`` row). Order
|
||||
is preserved for reproducibility: synonym_paraphrase first,
|
||||
then omit_arm, then omit_orientation, then omit_grasp_method,
|
||||
then combined_omissions.
|
||||
"""
|
||||
if not base_task:
|
||||
return []
|
||||
prompt = load_prompt("plan_task_aug_axes").format(
|
||||
base_task=base_task,
|
||||
n_synonym=axes_cfg.synonym_paraphrase,
|
||||
n_omit_arm=axes_cfg.omit_arm,
|
||||
n_omit_orientation=axes_cfg.omit_orientation,
|
||||
n_omit_grasp_method=axes_cfg.omit_grasp_method,
|
||||
n_combined=axes_cfg.combined_omissions,
|
||||
)
|
||||
result = self.vlm.generate_json([self._text_message(prompt)])[0]
|
||||
if not isinstance(result, dict):
|
||||
return []
|
||||
ordered_axes = (
|
||||
"synonym_paraphrase",
|
||||
"omit_arm",
|
||||
"omit_orientation",
|
||||
"omit_grasp_method",
|
||||
"combined_omissions",
|
||||
)
|
||||
flat: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for axis in ordered_axes:
|
||||
entries = result.get(axis)
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for item in entries:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
key = item.strip().strip('"').strip("'")
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
flat.append(key)
|
||||
return flat
|
||||
|
||||
def _episode_video_block(
|
||||
self, record: EpisodeRecord, window: tuple[float, float] | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Video block for the segmentation / describe prompts.
|
||||
|
||||
Always returns a block that actually carries the video. When
|
||||
``use_video_url`` is set we try the server-side ``video_url``
|
||||
path first, but if clip extraction fails we FALL BACK to
|
||||
decoding + embedding frames rather than returning an empty
|
||||
block — an empty block would leave the VLM with no visual
|
||||
grounding at all and it would hallucinate subtasks purely from
|
||||
the task text.
|
||||
|
||||
When ``window=(w0, w1)`` is given (windowed subtask generation,
|
||||
``subtask_window_seconds > 0``), embed frames sampled at the FIXED
|
||||
``frames_per_second`` rate within ``[w0, w1]`` — constant temporal
|
||||
density regardless of episode length, so long episodes are split
|
||||
into windows rather than subsampled to a sparse 32-frame whole-
|
||||
episode view. The ``video_url`` path is skipped for windows (it is
|
||||
a whole-episode clip). ``max_video_frames`` still caps each window
|
||||
as a context-budget safety net.
|
||||
"""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if window is not None:
|
||||
w0, w1 = float(window[0]), float(window[1])
|
||||
dur = max(0.0, w1 - w0)
|
||||
n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
|
||||
n = min(n, self.config.max_video_frames)
|
||||
if n <= 1 or dur <= 0.0:
|
||||
timestamps = [0.5 * (w0 + w1)]
|
||||
else:
|
||||
step = dur / (n - 1)
|
||||
timestamps = [w0 + i * step for i in range(n)]
|
||||
return to_video_block(self.frame_provider.frames_at(record, timestamps))
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||
if clip is not None:
|
||||
return to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
logger.warning(
|
||||
"episode %d: video_url clip extraction failed — falling back to "
|
||||
"embedded frames so the VLM still sees the demonstration",
|
||||
record.episode_index,
|
||||
)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
return to_video_block(video_frames)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections (event-driven). The
|
||||
interjection text is forwarded into the prompt so the refreshed plan
|
||||
reflects the user's correction.
|
||||
"""
|
||||
if not self.config.emit_plan:
|
||||
return
|
||||
existing = staging.read("plan")
|
||||
# Pass the last frame timestamp so the final span is closed (else its
|
||||
# end == start, zero duration, and a refresh inside it is missed).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("plan", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Generate subtask spans, optionally via a multi-call quality chain.
|
||||
|
||||
Single call (default): watch video → emit subtask JSON.
|
||||
|
||||
Multi-call (opt-in, higher quality, more VLM calls):
|
||||
1. ``subtask_describe_first`` — a grounding pass that narrates
|
||||
ONLY what is visible (no JSON commitment to subtasks yet);
|
||||
its description is injected into the segmentation prompt so
|
||||
the model segments its own grounded observations instead of
|
||||
pattern-matching the task text.
|
||||
2. segmentation — emit subtask JSON (as before).
|
||||
"""
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
effective_task = task if task is not None else record.episode_task
|
||||
|
||||
# ---- Windowed path (constant temporal density) ---------------
|
||||
# If subtask_window_seconds > 0 and the episode exceeds one window,
|
||||
# process fixed-length windows so the VLM always sees
|
||||
# frames_per_second density; results are merged + stitched.
|
||||
window_s = float(getattr(self.config, "subtask_window_seconds", 0.0) or 0.0)
|
||||
if window_s > 0.0 and episode_duration > window_s:
|
||||
return self._generate_subtasks_windowed(record, effective_task, window_s)
|
||||
|
||||
# ---- Pass 1 (optional): grounding description ----------------
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, effective_task)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the video) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
# ---- Pass 2: segmentation ------------------------------------
|
||||
prompt = load_prompt("plan_subtasks").format(
|
||||
episode_task=effective_task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||
cleaned = self._clean_spans(spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# ---- Full-episode coverage stitch ----------------------------
|
||||
# The VLM can start after t0 or leave gaps, so frames fall through
|
||||
# with no active subtask. Always stitch into a contiguous
|
||||
# [t0, t_last] cover.
|
||||
cleaned = self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _generate_subtasks_windowed(
|
||||
self, record: EpisodeRecord, task: str, window_s: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Subtask generation in fixed-length windows at constant fps.
|
||||
|
||||
Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
|
||||
seconds, runs the describe -> segment chain on each window's own
|
||||
frames (sampled at ``frames_per_second``), offsets
|
||||
each window's spans back to absolute episode time, then merges +
|
||||
stitches into a contiguous whole-episode cover.
|
||||
"""
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
all_spans: list[dict[str, Any]] = []
|
||||
w0 = t0
|
||||
n_windows = 0
|
||||
while w0 < t_last - 1e-6:
|
||||
w1 = min(w0 + window_s, t_last)
|
||||
all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
|
||||
n_windows += 1
|
||||
w0 = w1
|
||||
logger.info(
|
||||
"episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
|
||||
record.episode_index,
|
||||
n_windows,
|
||||
window_s,
|
||||
len(all_spans),
|
||||
)
|
||||
# Merge across windows: clamp to the absolute episode, sort, and
|
||||
# frame-snap to distinct starts (handles any boundary collisions).
|
||||
cleaned = self._clean_spans(all_spans, record)
|
||||
if not cleaned:
|
||||
return []
|
||||
return self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
def _subtasks_for_window(
|
||||
self, record: EpisodeRecord, task: str, w0: float, w1: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run describe -> segment on one ``[w0, w1]`` window.
|
||||
|
||||
The model works in window-RELATIVE time ``[0, L]`` (it perceives
|
||||
the window as a clip starting at 0); spans are offset back to
|
||||
absolute ``[w0, w1]`` before returning.
|
||||
"""
|
||||
window = (w0, w1)
|
||||
win_len = max(0.0, w1 - w0)
|
||||
|
||||
observation_block = ""
|
||||
if getattr(self.config, "subtask_describe_first", False):
|
||||
description = self._describe_episode(record, task, window=window)
|
||||
if description:
|
||||
observation_block = (
|
||||
"You watched this video clip and described, chronologically, "
|
||||
"ONLY what the robot actually does:\n"
|
||||
f'"""{description}"""\n\n'
|
||||
"Segment THAT grounded description (cross-checked against "
|
||||
"the clip) into atomic subtasks. Do not introduce any "
|
||||
"action that is not in your description above.\n\n"
|
||||
)
|
||||
|
||||
prompt = load_prompt("plan_subtasks").format(
|
||||
episode_task=task,
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{win_len:.3f}",
|
||||
observation_block=observation_block,
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
|
||||
# Window-relative clamp; no frame-snap dedupe yet (done on the
|
||||
# merged absolute set).
|
||||
cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# Offset window-relative spans back to absolute episode time.
|
||||
for s in cleaned:
|
||||
s["start"] = w0 + float(s["start"])
|
||||
s["end"] = w0 + float(s["end"])
|
||||
return cleaned
|
||||
|
||||
def _stitch_full_coverage(
|
||||
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Make subtask spans tile the full episode with no gaps.
|
||||
|
||||
* The first subtask starts at the episode's first frame ``t0``
|
||||
(any idle / approach before the first labelled action is folded
|
||||
into it), so every early frame has an active subtask.
|
||||
* Each subtask's ``end`` is snapped to the next subtask's
|
||||
``start`` (gaps between spans are closed), and the final
|
||||
subtask's ``end`` extends to the last frame ``t_last``.
|
||||
|
||||
Starts are otherwise left as the (already frame-snapped, distinct)
|
||||
values the VLM produced — only the FIRST start is pulled
|
||||
back to ``t0``, which can't collide with a later span because it
|
||||
was already the earliest. Purely deterministic; runs after the
|
||||
VLM passes.
|
||||
"""
|
||||
if not spans or not record.frame_timestamps:
|
||||
return spans
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
spans = sorted(spans, key=lambda s: float(s["start"]))
|
||||
spans[0]["start"] = t0
|
||||
for i in range(len(spans) - 1):
|
||||
spans[i]["end"] = float(spans[i + 1]["start"])
|
||||
spans[-1]["end"] = t_last
|
||||
for s in spans:
|
||||
if float(s["end"]) < float(s["start"]):
|
||||
s["end"] = float(s["start"])
|
||||
return spans
|
||||
|
||||
def _clean_spans(
|
||||
self,
|
||||
spans: Any,
|
||||
record: EpisodeRecord,
|
||||
bounds: tuple[float, float] | None = None,
|
||||
dedupe: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.
|
||||
|
||||
``bounds`` overrides the clamp range — pass the window's
|
||||
``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
|
||||
``None`` to clamp to the whole episode ``[t0, t_last]``.
|
||||
``dedupe`` runs the frame-snap distinct-start step; skip it for
|
||||
window-relative spans (frame snapping is done once on the merged,
|
||||
absolute-time set).
|
||||
"""
|
||||
if not spans:
|
||||
return []
|
||||
if bounds is not None:
|
||||
lo, hi = float(bounds[0]), float(bounds[1])
|
||||
else:
|
||||
lo = record.frame_timestamps[0]
|
||||
hi = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(lo, min(start, hi))
|
||||
end = max(lo, min(end, hi))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
if dedupe:
|
||||
return self._dedupe_starts_to_distinct_frames(cleaned, record)
|
||||
return cleaned
|
||||
|
||||
def _describe_episode(
|
||||
self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
|
||||
) -> str:
|
||||
"""Grounding pass: free-form chronological description of the (windowed) video."""
|
||||
prompt = load_prompt("plan_subtask_describe").format(episode_task=task)
|
||||
text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else ""
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_starts_to_distinct_frames(
|
||||
spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Bump same-frame subtask starts onto distinct frames.
|
||||
|
||||
Two consecutive VLM spans whose ``start`` rounds to the same
|
||||
source frame (after :func:`snap_to_frame`) would otherwise emit
|
||||
two ``style=subtask`` rows at the identical persistent
|
||||
timestamp. The training-time renderer's ``active_at(t,
|
||||
style=subtask)`` resolver can't disambiguate that and raises
|
||||
``Ambiguous resolver for style='subtask'``.
|
||||
|
||||
Walk the (sorted-by-start) spans, snap each to its frame, and
|
||||
if the snapped frame is already taken push the span onto the
|
||||
next unused frame so both subtasks survive on distinct
|
||||
timestamps. If the episode ends before a free frame is found,
|
||||
the trailing span is dropped with a warning — better than
|
||||
poisoning the render.
|
||||
"""
|
||||
if not spans:
|
||||
return spans
|
||||
frames = record.frame_timestamps
|
||||
if not frames:
|
||||
return spans
|
||||
used: set[float] = set()
|
||||
out: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
ts = snap_to_frame(span["start"], frames)
|
||||
if ts in used:
|
||||
next_ts = next((f for f in frames if f > ts and f not in used), None)
|
||||
if next_ts is None:
|
||||
logger.warning(
|
||||
"episode %d: subtask %r snapped to occupied frame "
|
||||
"%.3f and no free later frame exists — dropping",
|
||||
record.episode_index,
|
||||
span.get("text"),
|
||||
ts,
|
||||
)
|
||||
continue
|
||||
ts = next_ts
|
||||
used.add(ts)
|
||||
new_span = {**span, "start": ts}
|
||||
if float(new_span.get("end", ts)) < ts:
|
||||
new_span["end"] = ts
|
||||
out.append(new_span)
|
||||
return out
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord, # noqa: ARG002 (kept for signature stability)
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None, # noqa: ARG002
|
||||
task: str | None = None, # noqa: ARG002
|
||||
) -> str | None:
|
||||
"""Deterministic plan = numbered list of *still-todo* subtasks.
|
||||
|
||||
No VLM call: a plain numbered list keeps the plan aligned with the
|
||||
upcoming subtasks (the old VLM "compact hierarchical plan" prompt
|
||||
cost a round-trip per episode/refresh and could diverge).
|
||||
|
||||
1. <subtask 1>
|
||||
2. <subtask 2>
|
||||
|
||||
On a refresh at ``refresh_t`` (from ``run_plan_updates`` on
|
||||
interjections, and ``run_episode`` at each boundary), only subtasks
|
||||
starting at or after ``refresh_t`` are included — so it always
|
||||
describes what's left.
|
||||
"""
|
||||
if not subtask_spans:
|
||||
return None
|
||||
remaining = [
|
||||
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
|
||||
]
|
||||
if not remaining:
|
||||
# Past the last subtask boundary on a late refresh — nothing
|
||||
# left to plan; emit None so the caller skips the row.
|
||||
return None
|
||||
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("plan_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -1,12 +0,0 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: compact, confident, friendly.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Prefer very short replies: "Got it.", "On it.", "OK."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -1,46 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE compact interjection the user would naturally say at this
|
||||
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
|
||||
Keep it like a mid-task coaching cue, not a full instruction paragraph.
|
||||
Also write the robot's compact verbal acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One short phrase or sentence each. Conversational, not robotic.
|
||||
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
|
||||
- Keep robot speech very short: "OK.", "On it.", "Doing that."
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<short cue from the user, asking for the next subtask>",
|
||||
"speech": "<short robot acknowledgement>"
|
||||
}}
|
||||
@@ -1,36 +0,0 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the
|
||||
robot has accomplished so far — the running story it would tell itself.
|
||||
|
||||
Authoring rules:
|
||||
- First person, past tense. Every sentence starts with "I": "I picked
|
||||
up...", "I opened...", "I moved to...".
|
||||
- One or two short sentences. Extend the previous memory with the
|
||||
just-completed subtask; do not rewrite it from scratch.
|
||||
- Keep WHAT happened (functional outcomes — where items went, how many),
|
||||
drop HOW (grasp details, motions).
|
||||
- Compress completed steps and drop object attributes (colors, exact
|
||||
counts) once they no longer affect the remaining subtasks.
|
||||
|
||||
Example (MEM, Torne 2026):
|
||||
Before: "I prepared the pot and got the potatoes, milk, and butter. I
|
||||
moved to the drawer."
|
||||
After: "I prepared the pot and got the ingredients. I opened the
|
||||
drawer with the masher."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short first-person past-tense sentences>" }}
|
||||
@@ -1,27 +0,0 @@
|
||||
You are watching a teleoperated robot demonstration from a single
|
||||
camera. The user asked the robot to: "{episode_task}"
|
||||
|
||||
This is an OBSERVATION pass. Watch the entire clip and describe, in
|
||||
chronological order, ONLY what the robot physically does — the concrete
|
||||
motions, approaches, contacts, grasps, releases, and relocations you can
|
||||
actually SEE in the frames.
|
||||
|
||||
Hard rules:
|
||||
- Describe only motion visible in the video. Do NOT use the task
|
||||
instruction to guess steps that aren't shown. The instruction is the
|
||||
goal; the video is ground truth.
|
||||
- Do NOT segment into named subtasks yet and do NOT output JSON beyond
|
||||
the single field below. Just narrate what happens.
|
||||
- Give an approximate timestamp (in seconds) for each distinct event,
|
||||
e.g. "0.0-1.4s: the base drives forward toward the stove".
|
||||
- Do NOT invent objects, grasps, destinations, or steps. If the robot
|
||||
only does one thing (e.g. it just navigates and the clip ends), say
|
||||
exactly that and nothing more.
|
||||
- Be concrete and literal. "the gripper closes on the mug" — not "the
|
||||
robot prepares to make coffee".
|
||||
|
||||
Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"description": "<chronological, timestamped description of ONLY what is visible>"
|
||||
}}
|
||||
@@ -1,112 +0,0 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
{observation_block}GROUNDING — read this first, it overrides everything below:
|
||||
- Label ONLY what the robot actually does in the video. Every subtask
|
||||
you emit must correspond to motion you can SEE in specific frames.
|
||||
- Do NOT invent, anticipate, or pad. If the robot only does one thing
|
||||
(e.g. it just navigates to a location and the clip ends), emit
|
||||
EXACTLY ONE subtask. Many demonstrations are a single atomic skill.
|
||||
- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer
|
||||
subtasks than the ceiling is not just allowed, it is expected for
|
||||
short / atomic demonstrations. One correct subtask is far better
|
||||
than several invented ones.
|
||||
- If the video does not clearly show the action implied by the task,
|
||||
describe what you actually see — do NOT fabricate the task's steps
|
||||
from the instruction text. The instruction tells you the goal; the
|
||||
VIDEO is the ground truth for what happened.
|
||||
|
||||
Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts:
|
||||
|
||||
- Each subtask = one COMPOSITE atomic skill the low-level policy can
|
||||
execute end-to-end. A "skill" bundles its own approach motion with
|
||||
its terminal action — do NOT split the approach off as its own
|
||||
subtask. The whole-arm policy already learns to reach as part of
|
||||
every manipulation primitive.
|
||||
- Write each subtask as an IMPERATIVE COMMAND, starting with one of
|
||||
these verbs (extend only when none fits):
|
||||
pick up <obj> — approach + grasp + lift in one subtask
|
||||
put <obj> on/in <loc> — transport + release in one subtask
|
||||
place <obj> on/in <loc> — synonym of "put"; pick one and stay consistent
|
||||
push <obj> — contact + linear shove
|
||||
pull <obj> — contact + linear retract
|
||||
turn <knob/dial/handle> — rotary actuation
|
||||
press <button> — single-press contact
|
||||
open <drawer/door/lid> — full open motion
|
||||
close <drawer/door/lid> — full close motion
|
||||
pour <src> into <dst> — tilt + flow
|
||||
insert <obj> into <slot>— alignment + push-fit
|
||||
go to <loc> — ONLY when no grasp / actuation follows
|
||||
(e.g. a pure relocation between phases).
|
||||
If the next subtask grasps something at
|
||||
that location, drop "go to ..." and just
|
||||
write "pick up ..." instead.
|
||||
- Forbidden ultra-fine splits — the VLM is NOT allowed to emit these
|
||||
as standalone subtasks; fold them into the parent composite:
|
||||
"move to X" → fold into "pick up X" (or whatever follows)
|
||||
"reach for X" → fold into "pick up X"
|
||||
"grasp X" → fold into "pick up X"
|
||||
"lift X" → fold into "pick up X" (or "put X on Y" if it's
|
||||
the transport phase of a place)
|
||||
"release X" → fold into "put X on Y" (or "place X in Y")
|
||||
- Keep it SHORT — a verb phrase, not a sentence. Drop articles
|
||||
("the", "a") and adverbs ("carefully", "slowly"). Add a "how"
|
||||
detail (which hand, which grasp point) ONLY when it is needed to
|
||||
disambiguate. Every subtask must begin with one of the verbs
|
||||
above (no leading nouns, no "then", no "first").
|
||||
- NEVER use third person. Never write "the robot", "the arm", "the
|
||||
gripper moves", "it picks up" — the robot is implied. Command it,
|
||||
do not describe it.
|
||||
- Use the exact object nouns from the task above. If the task says
|
||||
"cube", every subtask says "cube" — never switch to "block". If it
|
||||
says "box", never switch to "bin"/"container". Keep vocabulary
|
||||
consistent across the whole episode.
|
||||
- Good: "pick up blue cube", "put blue cube in box", "open drawer",
|
||||
"turn red knob", "press start button", "go to sink".
|
||||
- Bad: "move to blue cube" (approach as its own subtask — forbidden,
|
||||
must be folded into "pick up blue cube"); "the robot arm moves
|
||||
towards the blue cube" (third person, too long); "carefully pick
|
||||
up the cube" (adverb, article); "release the yellow block"
|
||||
("block" when the task said "cube", and "release" must be folded
|
||||
into a "put"/"place" subtask).
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds. If a
|
||||
candidate span would be shorter, merge it into its neighbour
|
||||
rather than emitting it.
|
||||
- Do not exceed {max_steps} subtasks total. Fewer, larger composites
|
||||
are preferred over many micro-steps.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
SPECIAL CASES — verb disambiguation (each rule is narrowly visual and
|
||||
fires ONLY on the spatial situation it names; it must not change how you
|
||||
label any other situation):
|
||||
- STACK vs PUT: if an object is placed ON TOP OF another specific object
|
||||
(not on a flat table / shelf / counter), use "stack ... on ...", not
|
||||
"put". "stack blue book on green book", NOT "put blue book on table".
|
||||
- INSERT vs PUT: if an object goes INTO a fitted slot / hole / socket /
|
||||
receptacle (push-fit), use "insert ... into ...", not "put".
|
||||
- RETRIEVE/PICK-UP vs PUT (direction): watch the gripper. If it CLOSES
|
||||
on the object and the object moves WITH the hand, it is "pick up" /
|
||||
"retrieve" (object leaves its location). If the gripper OPENS and the
|
||||
object stays where the hand left it, it is "put" / "place" (object
|
||||
arrives at a location). Decide by which way the object moves, not by
|
||||
where the hand ends up.
|
||||
- POUR vs PUT: only use "pour" when the source is tilted and contents
|
||||
flow out; moving a full container without tilting is "put"/"place".
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<short imperative verb phrase>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,67 +0,0 @@
|
||||
You are generating structured augmentations of a robot task instruction
|
||||
for training a language-conditioned policy. Unlike free-form rephrasing,
|
||||
your variants follow a NAMED 5-axis taxonomy — each axis omits or varies
|
||||
a specific element of the task while preserving its meaning.
|
||||
|
||||
Original task: "{base_task}"
|
||||
|
||||
Produce variants along five named axes. Each axis has a target count.
|
||||
The whole batch should expose the policy to maximum linguistic diversity
|
||||
WITHOUT changing what the robot is supposed to do.
|
||||
|
||||
Axes and target counts:
|
||||
|
||||
synonym_paraphrase ({n_synonym}):
|
||||
Different wording / verbs / sentence structure. ALL information
|
||||
from the original task is preserved — same object, same arm
|
||||
specification if present, same orientation if present, same grasp
|
||||
if present.
|
||||
|
||||
omit_arm ({n_omit_arm}):
|
||||
Drop the left/right/both arm specification from the task. Skip
|
||||
entirely (emit 0 entries) if the original task does NOT mention an
|
||||
arm. Do not invent an arm specification just to omit it.
|
||||
|
||||
omit_orientation ({n_omit_orientation}):
|
||||
Drop orientation cues (upright, sideways, facing the user,
|
||||
long-edge-first, etc.). Skip entirely if no orientation cue is
|
||||
present in the original task.
|
||||
|
||||
omit_grasp_method ({n_omit_grasp_method}):
|
||||
Drop the grip / grasp method specification (pinch, wrap, hold by
|
||||
the rim, etc.). Skip entirely if no grasp method is mentioned.
|
||||
|
||||
combined_omissions ({n_combined}):
|
||||
Combine TWO of the above omissions simultaneously (e.g. drop both
|
||||
arm and orientation). Skip entirely if fewer than two of (arm,
|
||||
orientation, grasp_method) appear in the original task.
|
||||
|
||||
Hard rules:
|
||||
- Each variant MUST preserve the core action, the target object, AND
|
||||
the goal / destination. Do not change which object is involved, where
|
||||
it goes, or the high-level action. "Navigate to the stove" may become
|
||||
"go to the stove" or "head over to the stove" — it must NEVER become
|
||||
"wander around the kitchen", "explore the room", or anything that
|
||||
drops or generalises the stove destination. If you cannot vary the
|
||||
wording without changing the goal, emit fewer variants.
|
||||
- Only the FIVE listed elements (wording, arm, orientation, grasp
|
||||
method, or a combination) may be varied or omitted. The verb's
|
||||
meaning, the object, and the destination are fixed.
|
||||
- Each variant is plain prose, no markdown, no quotes, no list numbers.
|
||||
- Each variant must be DISTINCT from every other variant in the entire
|
||||
output, both within and across axes. Near-duplicates are not allowed.
|
||||
- If an axis cannot reach its target count because the original task
|
||||
lacks the omittable element, emit fewer entries — do NOT pad the
|
||||
axis with paraphrases that belong to a different axis.
|
||||
- Variants should not all start with verbs — vary sentence structure
|
||||
(some imperative, some polite request, some question).
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"synonym_paraphrase": ["<v1>", "<v2>", ...],
|
||||
"omit_arm": ["<v1>", "<v2>", ...],
|
||||
"omit_orientation": ["<v1>", ...],
|
||||
"omit_grasp_method": ["<v1>", ...],
|
||||
"combined_omissions": ["<v1>", ...]
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (mostly short imperative; occasional polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a short phrase or sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -1,17 +0,0 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -1,32 +0,0 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
@@ -1,216 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import load_tasks
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by the ``plan`` module (plan-update pass) and the
|
||||
``interjections`` module (interjection anchoring), which both need the
|
||||
same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame: event rows must land on an
|
||||
exact frame, otherwise the per-frame event lookup the writer does
|
||||
would never match them.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
"""Map ``task_index -> task`` from ``meta/tasks.parquet``.
|
||||
|
||||
Returns an empty dict when the file is absent — the task description is
|
||||
derived later from the video if needed. Reuses the library-level
|
||||
:func:`lerobot.datasets.io_utils.load_tasks`, which returns the tasks
|
||||
frame indexed by task string with a ``task_index`` column.
|
||||
"""
|
||||
if not (root / DEFAULT_TASKS_PATH).exists():
|
||||
return {}
|
||||
tasks = load_tasks(root)
|
||||
return {int(idx): str(task) for task, idx in zip(tasks.index, tasks["task_index"], strict=True)}
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"plan",
|
||||
"interjections",
|
||||
"vqa",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
@@ -1,332 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after all three modules have written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
|
||||
# ``_check_column_routing`` already recorded any unknown-style error;
|
||||
# don't let the same ``column_for_style`` lookup raise here uncaught.
|
||||
try:
|
||||
column = column_for_style(row.get("style"))
|
||||
except ValueError:
|
||||
continue
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
|
||||
return
|
||||
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "plan" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=plan emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"interjections", "vqa"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
@@ -1,599 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output from the server,
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.request
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client.
|
||||
|
||||
Only the ``openai`` backend is supported for now. The shipped workflow
|
||||
is Hugging Face Jobs (``examples/annotations/run_hf_job.py``): it boots
|
||||
a vLLM server inside the ``vllm/vllm-openai`` image and the pipeline
|
||||
talks to it over the OpenAI-compatible API (``--vlm.backend=openai``,
|
||||
optionally auto-spawning the server via ``auto_serve`` /
|
||||
``serve_command``). The former in-process ``vllm`` / ``transformers``
|
||||
backends were removed to keep the support surface to the HF Jobs path.
|
||||
|
||||
For ``stub``, construct :class:`StubVlmClient` directly with a responder
|
||||
callable; it is rejected here to make accidental misuse obvious.
|
||||
"""
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend in {"vllm", "transformers"}:
|
||||
raise ValueError(
|
||||
f"backend={config.backend!r} (in-process local model) is not supported for now — "
|
||||
"only backend='openai' (the Hugging Face Jobs flow) is. Run the pipeline via "
|
||||
"examples/annotations/run_hf_job.py, which serves the model with vLLM in the "
|
||||
"vllm/vllm-openai image and talks to it over the OpenAI-compatible API."
|
||||
)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
@@ -1,341 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants live in lerobot.datasets.language — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, Any], style: str | None, *, with_timestamp: bool) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the language-column struct shape.
|
||||
|
||||
Key order matches ``PERSISTENT_ROW_FIELDS`` / ``EVENT_ROW_FIELDS`` — the
|
||||
writer infers the parquet struct schema from insertion order, so
|
||||
``timestamp`` (persistent rows only) sits between ``style`` and ``camera``.
|
||||
"""
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
out: dict[str, Any] = {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
}
|
||||
if with_timestamp:
|
||||
out["timestamp"] = float(row["timestamp"])
|
||||
out["camera"] = None if camera is None else str(camera)
|
||||
out["tool_calls"] = _normalize_tool_calls(row.get("tool_calls"))
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Friendly error from the writer instead of a raw KeyError below;
|
||||
# the validator doesn't check ``role`` yet.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=True)
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
return _normalize_row(row, style, with_timestamp=False)
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — `tests/datasets/test_language.py`
|
||||
# exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -18,7 +18,6 @@ from __future__ import annotations
|
||||
# Utilities
|
||||
########################################################################################
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
@@ -244,72 +243,3 @@ def sanity_check_dataset_robot_compatibility(
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Teleoperator smooth handover helpers
|
||||
# NOTE(Maxime): These functions use minimal type hints to maintain compatibility with utils
|
||||
# being a root module.
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleop_supports_feedback(teleop) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
|
||||
Actuated teleops (e.g. SO-101, OpenArmMini) have non-empty ``feedback_features``
|
||||
and expose ``enable_torque`` / ``disable_torque`` motor-control methods.
|
||||
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fps: int = 30) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback (i.e. have non-empty
|
||||
``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
``target_pos`` is expected to be in the teleop's action/feedback key space.
|
||||
For homogeneous setups (e.g. SO-101 leader + SO-101 follower) this matches
|
||||
the robot action key space directly.
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds; during this time the
|
||||
follower robot does not receive new actions, which could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def follower_smooth_move_to(
|
||||
robot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm to
|
||||
the follower, the follower is brought to the teleop's current pose so the
|
||||
robot meets the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
Both ``current`` and ``target`` must be in the robot action key space
|
||||
(i.e. the output of ``robot_action_processor``).
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
@@ -205,149 +205,3 @@ class WandBLogger:
|
||||
|
||||
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
def log_training_examples(
|
||||
self,
|
||||
batch: dict,
|
||||
step: int,
|
||||
*,
|
||||
camera_keys: list[str],
|
||||
n_samples: int = 4,
|
||||
policy=None,
|
||||
predict_actions: bool = False,
|
||||
mode: str = "train",
|
||||
) -> None:
|
||||
"""Push a ``wandb.Table`` of training-example rows for the current batch.
|
||||
|
||||
Each row is one batch element with:
|
||||
* one ``wandb.Image`` column per camera in ``camera_keys`` (CHW or
|
||||
HWC, uint8 or float in [0,1] — auto-detected),
|
||||
* any text fields present in the batch (``task`` / ``subtask`` /
|
||||
``memory`` / ``instruction``),
|
||||
* ground-truth action first/last frame (the action chunk's
|
||||
endpoints — gives a quick sense of trajectory direction),
|
||||
* if ``predict_actions=True`` and ``policy`` is supplied, the model's
|
||||
``predict_action_chunk`` first/last frame alongside.
|
||||
|
||||
This is opt-in via ``--wandb.log_examples_freq=N`` on the CLI; the
|
||||
training loop calls it once every N steps. Cheap to keep on: with
|
||||
N=4 samples and 3 cameras you upload 12 small PNGs per dump and (if
|
||||
enabled) run one extra inference forward pass.
|
||||
"""
|
||||
import logging # noqa: PLC0415
|
||||
import numpy as np # noqa: PLC0415
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
# Batch size — first tensor-like value wins.
|
||||
bsz = next(
|
||||
(int(v.shape[0]) for v in batch.values() if hasattr(v, "shape") and v.ndim > 0),
|
||||
None,
|
||||
)
|
||||
if not bsz:
|
||||
return
|
||||
n = min(int(n_samples), bsz)
|
||||
|
||||
# Optional predicted-action forward pass on the first n samples.
|
||||
pred_actions: np.ndarray | None = None
|
||||
if predict_actions and policy is not None:
|
||||
was_training = policy.training
|
||||
try:
|
||||
policy.eval()
|
||||
sub_batch = {}
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
sub_batch[k] = v[:n]
|
||||
elif isinstance(v, (list, tuple)):
|
||||
sub_batch[k] = list(v[:n])
|
||||
else:
|
||||
sub_batch[k] = v
|
||||
with torch.no_grad():
|
||||
pred = policy.predict_action_chunk(sub_batch)
|
||||
pred_actions = pred.detach().cpu().float().numpy()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: predict_action_chunk failed (%s) — "
|
||||
"skipping predicted-action columns",
|
||||
exc,
|
||||
)
|
||||
pred_actions = None
|
||||
finally:
|
||||
if was_training:
|
||||
policy.train()
|
||||
|
||||
present_cameras = [c for c in camera_keys if c in batch]
|
||||
text_keys = [k for k in ("task", "subtask", "memory", "instruction") if k in batch]
|
||||
|
||||
columns = ["sample"]
|
||||
columns.extend(c.removeprefix("observation.images.") or c for c in present_cameras)
|
||||
columns.extend(text_keys)
|
||||
columns.append("gt_action_first")
|
||||
columns.append("gt_action_last")
|
||||
if pred_actions is not None:
|
||||
columns.append("pred_action_first")
|
||||
columns.append("pred_action_last")
|
||||
|
||||
table = self._wandb.Table(columns=columns)
|
||||
|
||||
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
|
||||
# Strip an outer time dim if present: (T, C, H, W) -> first frame.
|
||||
if t.ndim == 4:
|
||||
t = t[0]
|
||||
# CHW -> HWC.
|
||||
if t.ndim == 3 and t.shape[0] in (1, 3, 4) and t.shape[-1] not in (1, 3, 4):
|
||||
t = t.permute(1, 2, 0)
|
||||
arr = t.detach().cpu().float().numpy()
|
||||
if arr.size and float(arr.max()) <= 1.5:
|
||||
arr = arr * 255.0
|
||||
return np.clip(arr, 0, 255).astype(np.uint8)
|
||||
|
||||
def _action_endpoints(a: torch.Tensor) -> tuple[str, str]:
|
||||
arr = a.detach().cpu().float().numpy()
|
||||
if arr.ndim == 2: # (T, D)
|
||||
return (
|
||||
str(np.round(arr[0], 3).tolist()),
|
||||
str(np.round(arr[-1], 3).tolist()),
|
||||
)
|
||||
if arr.ndim == 1:
|
||||
rounded = np.round(arr, 3).tolist()
|
||||
return (str(rounded), str(rounded))
|
||||
return (str(arr.tolist()), str(arr.tolist()))
|
||||
|
||||
for i in range(n):
|
||||
row: list = [i]
|
||||
for cam in present_cameras:
|
||||
try:
|
||||
row.append(self._wandb.Image(_to_uint8_hwc(batch[cam][i])))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: camera %s sample %d failed (%s)",
|
||||
cam,
|
||||
i,
|
||||
exc,
|
||||
)
|
||||
row.append(None)
|
||||
for tk in text_keys:
|
||||
v = batch[tk]
|
||||
if isinstance(v, (list, tuple)):
|
||||
row.append(str(v[i]) if i < len(v) else "")
|
||||
else:
|
||||
row.append(str(v))
|
||||
action = batch.get("action")
|
||||
if isinstance(action, torch.Tensor) and action.ndim >= 1:
|
||||
first, last = _action_endpoints(action[i])
|
||||
row.append(first)
|
||||
row.append(last)
|
||||
else:
|
||||
row.append("")
|
||||
row.append("")
|
||||
if pred_actions is not None:
|
||||
p = torch.from_numpy(pred_actions[i])
|
||||
pfirst, plast = _action_endpoints(p)
|
||||
row.append(pfirst)
|
||||
row.append(plast)
|
||||
table.add_data(*row)
|
||||
|
||||
self._wandb.log({f"{mode}/examples": table}, step=step)
|
||||
|
||||
@@ -62,72 +62,6 @@ class WandBConfig:
|
||||
run_id: str | None = None
|
||||
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
|
||||
add_tags: bool = True # If True, save configuration as tags in the WandB run.
|
||||
# Periodic training-example dump (independent of ``log_freq``). When > 0,
|
||||
# every ``log_examples_freq`` steps the trainer pushes a ``wandb.Table``
|
||||
# with one row per sampled batch element containing each camera view
|
||||
# (rendered as ``wandb.Image``), any text fields present in the batch
|
||||
# (``task`` / ``subtask`` / ``memory`` / ``instruction``), and the
|
||||
# ground-truth action chunk's first + last frames. Defaults to 5000 — set
|
||||
# to 0 to disable. Only fires when ``enable=True``, so runs without wandb
|
||||
# are unaffected.
|
||||
log_examples_freq: int = 5000
|
||||
# Number of batch elements to include in each example dump.
|
||||
log_examples_n: int = 4
|
||||
# If True (default), also run ``policy.predict_action_chunk`` on the logged
|
||||
# samples (in eval mode, no_grad) and add predicted vs ground-truth action
|
||||
# columns to the table. Costs one extra forward pass per dump — negligible
|
||||
# at the 5k-step default cadence. Set to ``False`` if your policy doesn't
|
||||
# implement ``predict_action_chunk`` or you want to skip the extra forward.
|
||||
log_examples_predict_actions: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EMAConfig:
|
||||
"""Exponential Moving Average of trainable policy parameters.
|
||||
|
||||
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
|
||||
pi052) benefit substantially from averaging late-training
|
||||
parameter oscillations — see Chi et al. 2023 §V.D. The official
|
||||
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
|
||||
``0.999`` for its pi05_libero config; the openpi PyTorch port
|
||||
explicitly lists EMA as unsupported, and LeRobot main inherited
|
||||
that gap. Enabling this flag plugs ema-pytorch
|
||||
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
|
||||
training loop with a shadow ``nn.Module`` clone of the policy.
|
||||
|
||||
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
|
||||
params) + one elementwise update per training step (~1% step time).
|
||||
|
||||
Off by default (opt-in): EMA is only beneficial for flow-matching /
|
||||
diffusion policies (pi0/pi05/pi052), and the fp32 shadow copy is pure
|
||||
overhead for other policies (e.g. VLA-JEPA). Set ``--ema.enable=true``
|
||||
to turn it on (the pi05/pi052 training recipes do this). openpi (JAX)
|
||||
ships EMA on for every config; enable it explicitly to match that.
|
||||
"""
|
||||
|
||||
enable: bool = False
|
||||
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
|
||||
# ema-pytorch as ``beta``).
|
||||
# 0.999 — last ~1000 steps; pi05_libero default in openpi
|
||||
# 0.99 — last ~100 steps; openpi top-level default
|
||||
# 0.75 — very fast EMA (Diffusion Policy original setting)
|
||||
# 0.9999 — very slow EMA (long classification runs)
|
||||
decay: float = 0.99
|
||||
# Skip the first N calls to ``ema.update()``; during this window
|
||||
# the shadow is just a hard copy of the live weights (no averaging).
|
||||
# Lets early-training rapid changes settle before averaging begins.
|
||||
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
|
||||
# ramp like older lerobot EMA implementations).
|
||||
warmup_steps: int = 0
|
||||
# When True, the periodic eval block uses the EMA shadow model
|
||||
# directly (``ema.ema_model``) instead of the live policy. Standard
|
||||
# practice for diffusion-style policies — eval scores are usually
|
||||
# 1–3% higher than the live policy at the same step.
|
||||
use_for_eval: bool = True
|
||||
# When True, the periodic wandb training-example dump uses the EMA
|
||||
# shadow for the optional predicted-action columns (so what you see
|
||||
# in W&B matches eval behavior).
|
||||
use_for_wandb_examples: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -147,16 +147,7 @@ class TrainingRecipe:
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and the recipe supervises something.
|
||||
|
||||
A recipe is valid if it has at least one of:
|
||||
|
||||
* a ``target: true`` assistant turn (drives text-CE supervision), or
|
||||
* a ``stream: low_level`` turn (drives flow / action supervision via
|
||||
``predict_actions=True``, even when no assistant turn is targeted —
|
||||
e.g. π0.5-style ``low_level_execution`` where the action expert
|
||||
conditions on a user-only ``${subtask}`` prompt).
|
||||
"""
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
@@ -165,14 +156,8 @@ class TrainingRecipe:
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
has_target = any(turn.target for turn in self.messages)
|
||||
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
|
||||
if not (has_target or has_low_level):
|
||||
raise ValueError(
|
||||
"Message recipes must contain at least one supervised turn — "
|
||||
"either ``target: true`` (text CE) or ``stream: low_level`` "
|
||||
"(flow/action loss)."
|
||||
)
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.30
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.55
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into a
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on this dataset.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
@@ -1,99 +0,0 @@
|
||||
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
|
||||
#
|
||||
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
|
||||
# camera-grounded VQA across the three RoboCasa camera keys produced
|
||||
# by ``slurm_build_robocasa_composite_seen.py``:
|
||||
#
|
||||
# observation.images.robot0_agentview_left (left scene view)
|
||||
# observation.images.robot0_agentview_right (right scene view)
|
||||
# observation.images.robot0_eye_in_hand (wrist)
|
||||
#
|
||||
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
|
||||
# VQA per camera, so each anchor frame produces three (user, assistant)
|
||||
# rows tagged with their source camera. Each VQA sub-recipe consumes
|
||||
# the rows for one camera via ``camera=...`` resolver bindings.
|
||||
#
|
||||
# Spatial VQA targets (bbox / point) are rewritten from JSON to
|
||||
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
|
||||
# ``register_paligemma_loc_tokens`` already collapses them to single
|
||||
# detection-vocab ids so the LM head learns the pretrained pointing /
|
||||
# detection prior, not a 7-piece BPE salad.
|
||||
#
|
||||
# Interjections / spoken responses are intentionally absent — the
|
||||
# annotation job runs with ``--interjections.enabled=false``.
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.45
|
||||
messages:
|
||||
# Action expert is conditioned on the SUBTASK; at inference the
|
||||
# high-level loop generates it via the LM head and feeds it here.
|
||||
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
|
||||
# loss fires; subtask CE is owned by ``high_level_subtask``.
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# Trained densely with ``active_at`` — every frame inside a subtask
|
||||
# interval — so the (prior_memory, completed_subtask) → current_memory
|
||||
# mapping is supervised against varied observations. The *when* to
|
||||
# emit lives in the inference trigger (subtask_change), not the
|
||||
# model. See ``subtask_mem.yaml`` for the long version of this note.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
ask_vqa_agentview_left:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_left}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_agentview_right:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_right}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_eye_in_hand}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -1,114 +0,0 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into the
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on Hi-Robot-style datasets.
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
# Spoken reply only: the assistant turn carries no text content,
|
||||
# just a `say` tool call (`tool_calls_from: speech`). The chat
|
||||
# tokenizer flattens it to a `<say>...</say>` marker, so the
|
||||
# supervised target trains the model to respond to an
|
||||
# interjection with a spoken acknowledgement.
|
||||
- {role: assistant, stream: high_level, target: true, if_present: speech, tool_calls_from: speech}
|
||||
|
||||
# VQA is view-dependent — each camera gets its own sub-recipe so the
|
||||
# resolver disambiguates via `camera=...`. Camera keys match
|
||||
# subtasks_vqa.yaml (`front` + `wrist`); adjust to your dataset.
|
||||
ask_vqa_top:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -1,61 +0,0 @@
|
||||
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
|
||||
#
|
||||
# Trains two things only: subtasks and VQA. Plan and memory are
|
||||
# intentionally left out — keeps the prompt short and the training
|
||||
# surface small. The fuller blend with memory + spoken replies is
|
||||
# ``subtask_mem_vqa_speech.yaml``.
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# PI052's text tokenizer renders these messages as plain
|
||||
# ``Role: content`` text (PaliGemma is not chat-pretrained).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.40
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# the high-level loop (``HighLevelSubtaskFwd``) generates the
|
||||
# subtask via the LM head and feeds it here. The action expert's
|
||||
# prefix is [images, subtask, state]. ``stream: low_level`` flips
|
||||
# ``predict_actions=True`` so the flow loss fires; no text-CE
|
||||
# target here (subtask prediction is owned by
|
||||
# ``high_level_subtask``).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -111,20 +111,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
ema: EMAConfig = field(default_factory=EMAConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
|
||||
# carrying a `vqa` language annotation often enough that they make
|
||||
# up roughly this fraction of the training stream. VQA annotations
|
||||
# are typically sparse, so without this they are underrepresented.
|
||||
# `None` (default) keeps uniform episode-aware sampling.
|
||||
vqa_target_fraction: float | None = None
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training). Old
|
||||
# inline ``use_rabc`` / ``rabc_*`` params are migrated to this
|
||||
# field by ``_migrate_legacy_rabc_keys`` above.
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
|
||||
@@ -35,6 +35,7 @@ from .dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
@@ -49,24 +50,11 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
|
||||
def make_dataset(*args, **kwargs):
|
||||
from .factory import make_dataset as _make_dataset
|
||||
|
||||
return _make_dataset(*args, **kwargs)
|
||||
|
||||
|
||||
def resolve_delta_timestamps(*args, **kwargs):
|
||||
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
|
||||
|
||||
return _resolve_delta_timestamps(*args, **kwargs)
|
||||
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
# Import directly: ``from lerobot.datasets.io_utils import ...``
|
||||
@@ -77,7 +65,6 @@ __all__ = [
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"WeightedEpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
|
||||
@@ -126,53 +126,10 @@ class DatasetReader:
|
||||
def _load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self._meta.features)
|
||||
# Datasets annotated with the PR1 language columns may have been
|
||||
# written without registering those columns in ``meta/info.json``
|
||||
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
|
||||
# back-filled by ``lerobot-annotate``). Probe a single parquet
|
||||
# shard and graft the column features on so the strict
|
||||
# ``Dataset.from_parquet`` cast doesn't fail with
|
||||
# ``column names don't match``.
|
||||
features = self._extend_features_with_language_columns(features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def _extend_features_with_language_columns(
|
||||
self, features: datasets.Features
|
||||
) -> datasets.Features:
|
||||
"""Add ``language_persistent`` / ``language_events`` to ``features``
|
||||
when the underlying parquet shards declare them but the metadata
|
||||
doesn't. No-op when neither column is present or both are
|
||||
already registered.
|
||||
"""
|
||||
# Find any one parquet to peek at; bail if there are none yet
|
||||
# (the dataset will fail later for an unrelated reason and we
|
||||
# want that error to surface as-is).
|
||||
try:
|
||||
sample = next((self.root / "data").glob("*/*.parquet"))
|
||||
except StopIteration:
|
||||
return features
|
||||
|
||||
from pyarrow import parquet as _pq # noqa: PLC0415
|
||||
|
||||
schema_names = set(_pq.read_schema(sample).names)
|
||||
from .language import ( # noqa: PLC0415
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
|
||||
extra: dict[str, object] = {}
|
||||
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
|
||||
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
|
||||
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
|
||||
extra[LANGUAGE_EVENTS] = language_events_column_feature()
|
||||
if not extra:
|
||||
return features
|
||||
return datasets.Features({**features, **extra})
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
|
||||
@@ -170,29 +170,6 @@ def render_sample(
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
|
||||
# VQA-priority routing. A ``vqa`` annotation is sparse and
|
||||
# view-dependent; the plain weighted blend would (a) waste a draw
|
||||
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
|
||||
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
|
||||
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
|
||||
# sub-recipes and *this* frame carries one of their VQA bindings,
|
||||
# render VQA here regardless of the weighted draw. That makes VQA's
|
||||
# recipe-side training share equal the VQA-annotation density (the
|
||||
# maximum reachable without a dataset-level oversampling sampler).
|
||||
if recipe.blend is not None:
|
||||
vqa_rendered = _render_vqa_if_present(
|
||||
recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
if vqa_rendered is not None:
|
||||
return vqa_rendered
|
||||
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
@@ -206,59 +183,6 @@ def render_sample(
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _render_vqa_if_present(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
|
||||
annotation; otherwise return ``None`` so the caller falls back to the
|
||||
normal weighted blend.
|
||||
|
||||
When several VQA sub-recipes resolve (e.g. a frame annotated for more
|
||||
than one camera), one is chosen deterministically by relative weight.
|
||||
"""
|
||||
assert recipe.blend is not None
|
||||
renderable: list[tuple[float, RenderedMessages]] = []
|
||||
for name, component in recipe.blend.items():
|
||||
if not name.startswith("ask_vqa"):
|
||||
continue
|
||||
bindings = _resolve_bindings(
|
||||
component,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
rendered = _render_message_recipe(component, bindings)
|
||||
if rendered is not None:
|
||||
renderable.append((float(component.weight or 0.0), rendered))
|
||||
|
||||
if not renderable:
|
||||
return None
|
||||
if len(renderable) == 1:
|
||||
return renderable[0][1]
|
||||
|
||||
# Multiple cameras have a VQA for this frame — deterministic pick by
|
||||
# relative weight (fall back to a uniform draw if all weights are 0).
|
||||
total = sum(w for w, _ in renderable) or float(len(renderable))
|
||||
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total
|
||||
cumulative = 0.0
|
||||
for w, rendered in renderable:
|
||||
cumulative += w or (total / len(renderable))
|
||||
if draw < cumulative:
|
||||
return rendered
|
||||
return renderable[-1][1]
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
@@ -422,15 +346,7 @@ def _render_message_recipe(
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
# A render is meaningful if it supervises *something*: either a
|
||||
# text-CE target turn, or a ``low_level`` stream turn (flow / action
|
||||
# supervision — e.g. the flow-only ``low_level_execution`` recipe,
|
||||
# ``user(${subtask})`` with ``stream: low_level`` and no target).
|
||||
# Without this, a flow-only recipe renders to ``None`` every time
|
||||
# the blend draws it → ``predict_actions`` is never True → the
|
||||
# action expert never receives a flow loss.
|
||||
has_low_level = any(stream == "low_level" for stream in streams)
|
||||
if not target_indices and not has_low_level:
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
@@ -487,10 +403,8 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
# Valid iff it supervises something: a text-CE target turn OR a
|
||||
# ``low_level`` stream turn (flow / action supervision).
|
||||
if not target_indices and not any(s == "low_level" for s in streams):
|
||||
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
|
||||
@@ -84,66 +84,3 @@ class EpisodeAwareSampler:
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
|
||||
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
|
||||
proportion to per-frame weights.
|
||||
|
||||
Used to oversample frames carrying a sparse annotation (e.g. a VQA
|
||||
question) so the policy sees them more often than their natural
|
||||
dataset density. One epoch still yields ``len(self.indices)``
|
||||
samples — the weights only change the *composition* of the stream,
|
||||
not its length. Each epoch re-draws, so the oversampled subset
|
||||
varies run to run.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
frame_weights,
|
||||
*,
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
|
||||
dataset_to_indices: Episode end indices.
|
||||
frame_weights: 1-D sequence/tensor of non-negative weights, one per
|
||||
dataset frame (length == total dataset frames). Higher weight ⇒
|
||||
that frame is sampled more often.
|
||||
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
|
||||
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
|
||||
frame filtering is applied first, then weighting is restricted
|
||||
to the surviving frames.
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_from_indices,
|
||||
dataset_to_indices,
|
||||
episode_indices_to_use=episode_indices_to_use,
|
||||
drop_n_first_frames=drop_n_first_frames,
|
||||
drop_n_last_frames=drop_n_last_frames,
|
||||
shuffle=False,
|
||||
)
|
||||
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
|
||||
idx = torch.tensor(self.indices, dtype=torch.long)
|
||||
if weights.numel() <= int(idx.max()):
|
||||
raise ValueError(
|
||||
f"frame_weights has {weights.numel()} entries but the sampler "
|
||||
f"references frame index {int(idx.max())}."
|
||||
)
|
||||
selected = weights[idx]
|
||||
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
|
||||
raise ValueError("frame_weights must be finite and non-negative.")
|
||||
if float(selected.sum()) <= 0.0:
|
||||
# All surviving frames have zero weight — fall back to uniform.
|
||||
selected = torch.ones_like(selected)
|
||||
self._weights = selected
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
|
||||
for i in picks.tolist():
|
||||
yield self.indices[i]
|
||||
|
||||
@@ -366,24 +366,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
msg = (
|
||||
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
|
||||
f"either doesn't exist on the Hub yet, or it was uploaded "
|
||||
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
|
||||
f" from huggingface_hub import HfApi\n"
|
||||
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
)
|
||||
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
|
||||
# ``__init__`` indexes ``response.headers`` unconditionally on
|
||||
# current ``huggingface_hub`` versions. Constructing it without
|
||||
# a real ``Response`` object crashes with either
|
||||
# ``TypeError: missing 1 required keyword-only argument`` (old
|
||||
# builds) or ``AttributeError: 'NoneType' object has no attribute
|
||||
# 'headers'`` (new builds). Skip that path entirely — this isn't
|
||||
# really an HTTP error, it's a configuration issue — and raise a
|
||||
# plain ``RuntimeError`` so the message actually reaches the
|
||||
# caller.
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Dimensions for the flat action/state vectors used by the LeRobot wrapper.
|
||||
# These correspond to the PandaOmron robot in RoboCasa365.
|
||||
OBS_STATE_DIM = 16 # ee_pos_rel(3) + ee_quat_rel(4) + base_pos(3) + base_quat(4) + gripper_qpos(2)
|
||||
ACTION_DIM = 12 # ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
|
||||
OBS_STATE_DIM = 16 # base_pos(3) + base_quat(4) + ee_pos_rel(3) + ee_quat_rel(4) + gripper_qpos(2)
|
||||
ACTION_DIM = 12 # base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
@@ -101,15 +101,14 @@ def _resolve_tasks(task: str) -> tuple[list[str], str | None]:
|
||||
def convert_action(flat_action: np.ndarray) -> dict[str, Any]:
|
||||
"""Split a flat (12,) action vector into a RoboCasa action dict.
|
||||
|
||||
Layout (openpi / robocasa.utils.env_utils.convert_action order):
|
||||
ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
|
||||
Layout: base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
"""
|
||||
return {
|
||||
"action.end_effector_position": flat_action[0:3],
|
||||
"action.end_effector_rotation": flat_action[3:6],
|
||||
"action.gripper_close": flat_action[6:7],
|
||||
"action.base_motion": flat_action[7:11],
|
||||
"action.control_mode": flat_action[11:12],
|
||||
"action.base_motion": flat_action[0:4],
|
||||
"action.control_mode": flat_action[4:5],
|
||||
"action.end_effector_position": flat_action[5:8],
|
||||
"action.end_effector_rotation": flat_action[8:11],
|
||||
"action.gripper_close": flat_action[11:12],
|
||||
}
|
||||
|
||||
|
||||
@@ -231,14 +230,12 @@ class RoboCasaEnv(gym.Env):
|
||||
return {"pixels": images}
|
||||
|
||||
# `state.*` keys come from PandaOmronKeyConverter inside the wrapper.
|
||||
# openpi state order: ee first, then base, then gripper (matches the
|
||||
# openpi robocasa pipeline / examples/robocasa/main.py state layout).
|
||||
agent_pos = np.concatenate(
|
||||
[
|
||||
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
|
||||
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
|
||||
raw_obs.get("state.base_position", np.zeros(3)),
|
||||
raw_obs.get("state.base_rotation", np.zeros(4)),
|
||||
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
|
||||
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
|
||||
raw_obs.get("state.gripper_qpos", np.zeros(2)),
|
||||
],
|
||||
axis=-1,
|
||||
|
||||
@@ -104,8 +104,6 @@ class AdamWConfig(OptimizerConfig):
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
foreach: bool | None = None
|
||||
fused: bool | None = None
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
|
||||
@@ -25,7 +25,6 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pi052.configuration_pi052 import PI052Config as PI052Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
@@ -50,7 +49,6 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"PI052Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
+16
-146
@@ -63,79 +63,6 @@ from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
|
||||
|
||||
def _restore_pi052_pretrained_state(
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
pretrained_path: str,
|
||||
) -> None:
|
||||
"""Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.
|
||||
|
||||
pi052's preprocessor includes steps whose constructor args don't
|
||||
JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
|
||||
``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
|
||||
fitted-tokenizer path that may not exist at eval time). We rebuild
|
||||
those pipelines fresh from ``config.recipe_path`` and then walk
|
||||
over the saved ``policy_{pre,post}processor.json`` files to find
|
||||
each step's ``state_file`` reference and load the bytes back into
|
||||
the corresponding fresh step. Today that's only the
|
||||
NormalizerProcessorStep / UnnormalizerProcessorStep (the action /
|
||||
state quantile stats), but the loop is generic so any future
|
||||
stateful step picks up its blob automatically.
|
||||
|
||||
Pairing is by ``registry_name`` AND position so a benign reorder
|
||||
on the saved side surfaces a warning rather than silently feeding
|
||||
the wrong tensors into the wrong step.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
import logging # noqa: PLC0415
|
||||
from pathlib import Path # noqa: PLC0415
|
||||
|
||||
from safetensors.torch import load_file # noqa: PLC0415
|
||||
|
||||
base = Path(pretrained_path)
|
||||
if not base.exists():
|
||||
return
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
for pipeline, config_filename in [
|
||||
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
|
||||
(postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
|
||||
]:
|
||||
config_path = base / config_filename
|
||||
if not config_path.exists():
|
||||
continue
|
||||
saved = json.loads(config_path.read_text())
|
||||
|
||||
for idx, (saved_step, fresh_step) in enumerate(
|
||||
zip(saved.get("steps", []), pipeline.steps, strict=False)
|
||||
):
|
||||
state_file = saved_step.get("state_file")
|
||||
if not state_file:
|
||||
continue
|
||||
saved_name = saved_step.get("registry_name")
|
||||
fresh_name = getattr(type(fresh_step), "_registry_name", None)
|
||||
if saved_name and fresh_name and saved_name != fresh_name:
|
||||
log.warning(
|
||||
"PI052 state restore: %s step %d registry name mismatch "
|
||||
"(saved=%s, fresh=%s); skipping %s",
|
||||
config_filename, idx, saved_name, fresh_name, state_file,
|
||||
)
|
||||
continue
|
||||
state_path = base / state_file
|
||||
if not state_path.exists():
|
||||
log.warning(
|
||||
"PI052 state restore: %s missing at %s; %s left at fresh init",
|
||||
state_file, base, fresh_name,
|
||||
)
|
||||
continue
|
||||
fresh_step.load_state_dict(load_file(str(state_path)))
|
||||
log.info(
|
||||
"PI052 state restore: loaded %s into %s (step %d)",
|
||||
state_file, fresh_name, idx,
|
||||
)
|
||||
|
||||
|
||||
def _reconnect_relative_absolute_steps(
|
||||
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
|
||||
) -> None:
|
||||
@@ -203,10 +130,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi052":
|
||||
from .pi052.modeling_pi052 import PI052Policy
|
||||
|
||||
return PI052Policy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
@@ -255,8 +178,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
|
||||
"pi052", "gaussian_actor", "smolvla", "wall_x", "molmoact2".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x", "molmoact2".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -279,10 +202,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi052":
|
||||
from .pi052.configuration_pi052 import PI052Config
|
||||
|
||||
return PI052Config(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -327,12 +246,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
# Optional: HF Hub repo id of the dataset the policy is being
|
||||
# trained on. Used by policies that auto-fit pieces of their
|
||||
# preprocessing (e.g. pi052's FAST action tokenizer per
|
||||
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
|
||||
# policies fall back to their universal pre-fitted tokenizers.
|
||||
dataset_repo_id: str | None
|
||||
dataset_meta: Any | None
|
||||
|
||||
|
||||
@@ -366,50 +279,23 @@ def make_pre_post_processors(
|
||||
NotImplementedError: If a processor factory is not implemented for the given
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
|
||||
# pi052 pipelines don't roundtrip through the saved
|
||||
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
|
||||
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
|
||||
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
|
||||
# FAST tokenizer path. Generic ``from_pretrained`` then dies
|
||||
# with ``RenderMessagesStep.__init__() missing 1 required
|
||||
# positional argument: 'recipe'`` (job 22164494).
|
||||
#
|
||||
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
|
||||
# fresh from ``config.recipe_path`` and transplant the saved
|
||||
# stateful blobs (normalizer stats) from the checkpoint dir.
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
preprocessor, postprocessor = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
# GROOT handles normalization in groot_pack_inputs_v3 step
|
||||
# Need to override both stats AND normalize_min_max since saved config might be empty
|
||||
preprocessor_overrides = {}
|
||||
postprocessor_overrides = {}
|
||||
preprocessor_overrides["groot_pack_inputs_v3"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
}
|
||||
from .groot.processor_groot import make_groot_pre_post_processors_from_pretrained
|
||||
|
||||
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
|
||||
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
|
||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
|
||||
"stats": kwargs.get("dataset_stats"),
|
||||
"normalize_min_max": True,
|
||||
"env_action_dim": env_action_dim,
|
||||
}
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
return make_groot_pre_post_processors_from_pretrained(
|
||||
config=policy_cfg,
|
||||
pretrained_path=pretrained_path,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_overrides=kwargs.get("preprocessor_overrides"),
|
||||
postprocessor_overrides=kwargs.get("postprocessor_overrides"),
|
||||
preprocessor_config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
postprocessor_config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -483,22 +369,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "pi052":
|
||||
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
|
||||
# come before the PI05Config isinstance check below (otherwise
|
||||
# pi052 would silently pick up π0.5's processor).
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
processors = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
|
||||
# enabled — the train loop sets it from ``--dataset.repo_id``.
|
||||
# When ``None``, ``make_pi052_pre_post_processors`` skips
|
||||
# the auto-fit and uses the universal tokenizer.
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from .pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
|
||||
@@ -18,4 +18,12 @@ from .configuration_groot import GrootConfig
|
||||
from .modeling_groot import GrootPolicy
|
||||
from .processor_groot import make_groot_pre_post_processors
|
||||
|
||||
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in {"GR00TN17", "GR00TN17Config"}:
|
||||
from .groot_n1_7 import GR00TN17, GR00TN17Config
|
||||
|
||||
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Produces a sinusoidal encoding of shape (B, T, w)
|
||||
given timesteps of shape (B, T).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
# timesteps: shape (B, T)
|
||||
# We'll compute sin/cos frequencies across dim T
|
||||
timesteps = timesteps.float() # ensure float
|
||||
|
||||
b, t = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embedding_dim // 2
|
||||
# typical log space frequencies for sinusoidal encoding
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
# Expand timesteps to (B, T, 1) then multiply
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -42,6 +43,9 @@ else:
|
||||
Timesteps = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
require_package("diffusers", extra="groot")
|
||||
@@ -181,8 +185,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
@@ -266,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
logger.debug(
|
||||
"Total number of DiT parameters: %d",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
@@ -318,6 +321,71 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class AlternateVLDiT(DiT):
|
||||
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
|
||||
|
||||
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.attend_text_every_n_blocks = attend_text_every_n_blocks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
image_mask: torch.Tensor | None = None,
|
||||
backbone_attention_mask: torch.Tensor | None = None,
|
||||
):
|
||||
if image_mask is None:
|
||||
raise ValueError("image_mask is required for AlternateVLDiT.")
|
||||
if backbone_attention_mask is None:
|
||||
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
|
||||
|
||||
temb = self.timestep_encoder(timestep)
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
image_attention_mask = image_mask & backbone_attention_mask
|
||||
non_image_attention_mask = (~image_mask) & backbone_attention_mask
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
if not self.config.interleave_self_attention:
|
||||
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
|
||||
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
curr_encoder_attention_mask = (
|
||||
non_image_attention_mask
|
||||
if idx % (2 * self.attend_text_every_n_blocks) == 0
|
||||
else image_attention_mask
|
||||
)
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=curr_encoder_attention_mask,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@@ -362,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
logger.debug(
|
||||
"Total number of SelfAttentionTransformer parameters: %d",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,408 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
PretrainedConfig = object
|
||||
BatchFeature = None
|
||||
|
||||
from .action_encoder import (
|
||||
SinusoidalPositionalEncoding,
|
||||
swish,
|
||||
)
|
||||
from .cross_attention_dit import DiT, SelfAttentionTransformer
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
b, t, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == b:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, t)
|
||||
else:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
|
||||
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
|
||||
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
|
||||
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
|
||||
backbone_embedding_dim: int = field(
|
||||
default=1536, metadata={"help": "Backbone embedding channel dimension."}
|
||||
)
|
||||
|
||||
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
|
||||
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
|
||||
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
|
||||
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
|
||||
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
|
||||
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
|
||||
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
|
||||
num_timestep_buckets: int = field(
|
||||
default=1000, metadata={"help": "Number of timestep discretization buckets."}
|
||||
)
|
||||
num_inference_timesteps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of inference steps for noise diffusion."},
|
||||
)
|
||||
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
|
||||
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
|
||||
tune_diffusion_model: bool = field(
|
||||
default=True, metadata={"help": "Whether to tune the diffusion model."}
|
||||
)
|
||||
load_pretrained_det_decode_layer_path: str = field(
|
||||
default=None, metadata={"help": "Path to pretrained detection model."}
|
||||
)
|
||||
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
|
||||
|
||||
freeze_decode_layer: bool = field(default=False)
|
||||
expand_batch: int = field(default=None)
|
||||
use_vlln: bool = field(default=True)
|
||||
|
||||
vl_self_attention_cfg: dict = field(default=None)
|
||||
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
config_class = FlowmatchingActionHeadConfig
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FlowmatchingActionHeadConfig,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
self.model = DiT(**config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
|
||||
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
|
||||
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
self.vl_self_attention = (
|
||||
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
|
||||
)
|
||||
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.config = config
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
|
||||
|
||||
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
print(f"Tune action head projector: {self.tune_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_projector and not tune_diffusion_model:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Action head trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
if self._beta_dist is None:
|
||||
self._beta_dist = Beta(self._noise_beta_alpha, self._noise_beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = backbone_output["backbone_features"]
|
||||
backbone_features = self.vlln(backbone_features)
|
||||
backbone_features = self.vl_self_attention(backbone_features)
|
||||
backbone_output["backbone_features"] = backbone_features
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
# Set frozen modules to eval
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
if self.config.expand_batch is not None:
|
||||
for k, v in backbone_output.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
backbone_output[k] = expanded
|
||||
|
||||
for k, v in action_input.items():
|
||||
ndim = len(v.shape)
|
||||
factors = [self.config.expand_batch]
|
||||
while len(factors) < ndim:
|
||||
factors.append(1)
|
||||
factors = tuple(factors)
|
||||
expanded = v.repeat(*factors)
|
||||
action_input[k] = expanded
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
device = vl_embs.device
|
||||
|
||||
# Get embodiment ID.
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
vl_attn_mask = backbone_output.backbone_attention_mask
|
||||
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=vl_attn_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=False, # NOTE (YL): not using flare now
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Slice out only the action portion of pred and target.
|
||||
action_mask = action_input.action_mask
|
||||
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = loss.sum() / action_mask.sum()
|
||||
output_dict = {
|
||||
"loss": loss,
|
||||
}
|
||||
return BatchFeature(data=output_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
|
||||
# Get vision and language embeddings.
|
||||
vl_embs = backbone_output.backbone_features
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
# Embed state.
|
||||
state_features = self.state_encoder(action_input.state, embodiment_id)
|
||||
|
||||
# Set initial actions as the sampled noise.
|
||||
batch_size = vl_embs.shape[0]
|
||||
device = vl_embs.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=vl_embs.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# Run denoising steps.
|
||||
for t in range(num_steps):
|
||||
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# Embed noised action trajectory.
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
# Maybe add position embedding.
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
|
||||
action_features = action_features + pos_embs
|
||||
|
||||
# Join vision, language, state and action embedding along sequence dimension.
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
|
||||
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
|
||||
|
||||
# Run model forward.
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
|
||||
pred_velocity = pred[:, -self.action_horizon :]
|
||||
|
||||
# Update actions using euler integration.
|
||||
actions = actions + dt * pred_velocity
|
||||
return BatchFeature(data={"action_pred": actions})
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
@@ -14,12 +14,327 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||
GROOT_N1_5 = "n1.5"
|
||||
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
"GR00T N1.5 support was removed from LeRobot. "
|
||||
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
|
||||
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||
|
||||
_GROOT_MODEL_VERSION_ALIASES = {
|
||||
"n1.7": GROOT_N1_7,
|
||||
"n1_7": GROOT_N1_7,
|
||||
"n1d7": GROOT_N1_7,
|
||||
"n17": GROOT_N1_7,
|
||||
"1.7": GROOT_N1_7,
|
||||
}
|
||||
|
||||
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||
|
||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||
"none": None,
|
||||
"": None,
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||
}
|
||||
|
||||
|
||||
def normalize_groot_model_version(model_version: str) -> str:
|
||||
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
|
||||
if normalized is None:
|
||||
supported = GROOT_N1_7
|
||||
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
|
||||
if transform is None:
|
||||
return None
|
||||
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
|
||||
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
|
||||
supported = ", ".join(
|
||||
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
|
||||
f"Supported transforms: none, {supported}."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def infer_groot_model_version(model_path: str | None) -> str | None:
|
||||
if not model_path:
|
||||
return None
|
||||
model_path_lower = model_path.lower()
|
||||
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
|
||||
return GROOT_N1_7
|
||||
# Detect legacy N1.5 paths so the N1.7 config/loader can reject the mismatch.
|
||||
# N1.5 is unsupported, but it must still be recognised here to fail loudly
|
||||
# rather than silently treating an N1.5 checkpoint as N1.7.
|
||||
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
|
||||
return GROOT_N1_5
|
||||
config_version = _infer_groot_model_version_from_local_config(model_path)
|
||||
if config_version is not None:
|
||||
return config_version
|
||||
return None
|
||||
|
||||
|
||||
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
|
||||
if model_path is None:
|
||||
return False
|
||||
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
|
||||
|
||||
|
||||
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
if "libero_sim" in modality_configs:
|
||||
return "libero_sim"
|
||||
if len(modality_configs) == 1:
|
||||
return next(iter(modality_configs))
|
||||
return None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
if model_path is None:
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
return None
|
||||
modality_configs = processor_kwargs.get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag is None:
|
||||
return None
|
||||
|
||||
embodiment_config = modality_configs.get(embodiment_tag, {})
|
||||
if not isinstance(embodiment_config, dict):
|
||||
return None
|
||||
action_config = embodiment_config.get("action", {})
|
||||
if not isinstance(action_config, dict):
|
||||
return None
|
||||
delta_indices = action_config.get("delta_indices", [])
|
||||
if not isinstance(delta_indices, list):
|
||||
return None
|
||||
return len(delta_indices) or None
|
||||
|
||||
|
||||
def infer_groot_n1_7_action_execution_horizon(
|
||||
model_path: str | Path | None, embodiment_tag: str | None = None
|
||||
) -> int | None:
|
||||
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
|
||||
if action_horizon is None:
|
||||
return None
|
||||
|
||||
if embodiment_tag is None:
|
||||
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
|
||||
if embodiment_tag == "libero_sim":
|
||||
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
|
||||
# actions. Keeping that execution cadence avoids stale open-loop chunks.
|
||||
return min(action_horizon, 8)
|
||||
return action_horizon
|
||||
|
||||
|
||||
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
|
||||
model_path = Path(model_name).expanduser()
|
||||
if model_path.exists():
|
||||
return str(model_path)
|
||||
|
||||
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
|
||||
return str(cached_snapshot) if cached_snapshot is not None else model_name
|
||||
|
||||
|
||||
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
|
||||
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
|
||||
required_files = (
|
||||
"config.json",
|
||||
"tokenizer_config.json",
|
||||
"preprocessor_config.json",
|
||||
"video_preprocessor_config.json",
|
||||
)
|
||||
|
||||
for hub_cache in _candidate_hf_hub_caches(cache_dir):
|
||||
repo_cache = hub_cache / repo_cache_name
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
if not snapshots_dir.is_dir():
|
||||
continue
|
||||
|
||||
candidates: list[Path] = []
|
||||
ref_path = repo_cache / "refs" / "main"
|
||||
try:
|
||||
ref = ref_path.read_text().strip()
|
||||
except OSError:
|
||||
ref = ""
|
||||
if ref:
|
||||
candidates.append(snapshots_dir / ref)
|
||||
candidates.extend(
|
||||
sorted(
|
||||
(path for path in snapshots_dir.iterdir() if path.is_dir()),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[Path] = set()
|
||||
for snapshot in candidates:
|
||||
if snapshot in seen:
|
||||
continue
|
||||
seen.add(snapshot)
|
||||
if all((snapshot / filename).exists() for filename in required_files):
|
||||
return snapshot
|
||||
return None
|
||||
|
||||
|
||||
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
|
||||
candidates: list[Path] = []
|
||||
if cache_dir is not None:
|
||||
cache_path = Path(cache_dir).expanduser()
|
||||
candidates.append(cache_path)
|
||||
candidates.append(cache_path / "hub")
|
||||
|
||||
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
|
||||
if hub_cache:
|
||||
candidates.append(Path(hub_cache).expanduser())
|
||||
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if hf_home:
|
||||
candidates.append(Path(hf_home).expanduser() / "hub")
|
||||
|
||||
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
|
||||
|
||||
deduped: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for candidate in candidates:
|
||||
resolved = candidate.resolve() if candidate.exists() else candidate
|
||||
if resolved not in seen:
|
||||
seen.add(resolved)
|
||||
deduped.append(candidate)
|
||||
return deduped
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
config_path = path / "config.json"
|
||||
elif path.name == "config.json":
|
||||
config_path = path
|
||||
else:
|
||||
return None
|
||||
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return _infer_groot_model_version_from_config(config)
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
model_version = config.get("model_version")
|
||||
if isinstance(model_version, str):
|
||||
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||
return GROOT_N1_5
|
||||
try:
|
||||
return normalize_groot_model_version(model_version)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
|
||||
for candidate in candidates:
|
||||
if not isinstance(candidate, str):
|
||||
continue
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
return GROOT_N1_7
|
||||
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||
backbone_cfg = config.get("backbone_cfg")
|
||||
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||
return GROOT_N1_5
|
||||
return None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("groot")
|
||||
@dataclass
|
||||
@@ -28,35 +343,44 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Basic policy settings
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
chunk_size: int = 40
|
||||
n_action_steps: int = 40
|
||||
|
||||
# Dimension settings (must match pretrained GR00T model expectations)
|
||||
# Maximum state dimension. Shorter states will be zero-padded.
|
||||
max_state_dim: int = 64
|
||||
max_state_dim: int = 132
|
||||
|
||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||
max_action_dim: int = 32
|
||||
max_action_dim: int = 132
|
||||
|
||||
# Normalization (start with identity, adjust as needed)
|
||||
# GR00T normalizes state/action internally in its processor steps (min/max with
|
||||
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
|
||||
# handles image normalization. The policy therefore does NOT use LeRobot's
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
|
||||
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Image preprocessing (adjust to match Groot's expected input)
|
||||
image_size: tuple[int, int] = (224, 224)
|
||||
# Groot-specific model parameters
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
base_model_path: str = "nvidia/GR00T-N1.5-3B"
|
||||
base_model_path: str | None = None
|
||||
|
||||
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
|
||||
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
@@ -96,17 +420,16 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Dataset parameters
|
||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
video_backend: str = "decord"
|
||||
|
||||
# Whether to balance dataset weights in mixture datasets
|
||||
balance_dataset_weights: bool = True
|
||||
|
||||
# Whether to sample trajectories weighted by their length
|
||||
balance_trajectory_weights: bool = True
|
||||
|
||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
||||
dataset_paths: list[str] | None = None
|
||||
output_dir: str = "./tmp/gr00t"
|
||||
save_steps: int = 1000
|
||||
@@ -117,6 +440,66 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
self.base_model_path = GROOT_N1_7_BASE_MODEL
|
||||
|
||||
# The N1.7 LIBERO checkpoints emit a [0, 1] gripper action, but the LIBERO
|
||||
# simulator expects the OpenVLA/[-1, 1] sign convention. NVIDIA's rollout
|
||||
# wrapper applies this conversion; mirror it here so eval on the
|
||||
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
||||
# This matches the embodiment-specific handling already done for the
|
||||
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
||||
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
|
||||
# 'none' (normalized to None above) keeps the transform disabled.
|
||||
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
|
||||
self.action_decode_transform = (
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
|
||||
)
|
||||
|
||||
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
|
||||
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
|
||||
# warning. The dataclass defaults are already the N1.7 values, so a plain
|
||||
# GrootConfig() never triggers this.
|
||||
legacy_default_remaps = (
|
||||
("max_state_dim", 64, 132),
|
||||
("max_action_dim", 32, 132),
|
||||
("chunk_size", 50, 40),
|
||||
("n_action_steps", 50, 40),
|
||||
("image_size", (224, 224), (256, 256)),
|
||||
)
|
||||
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
|
||||
current_value = getattr(self, field_name)
|
||||
if isinstance(legacy_value, tuple):
|
||||
current_value = tuple(current_value)
|
||||
if current_value == legacy_value:
|
||||
logger.warning(
|
||||
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
|
||||
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
|
||||
"if this is not what you want.",
|
||||
field_name,
|
||||
legacy_value,
|
||||
n1_7_value,
|
||||
)
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
@@ -192,7 +575,10 @@ class GrootConfig(PreTrainedConfig):
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions."""
|
||||
return list(range(min(self.chunk_size, 16)))
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
|
||||
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Eagle25VLConfig(PretrainedConfig):
|
||||
model_type = "eagle_2_5_vl"
|
||||
is_composition = True
|
||||
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
pad2square=False,
|
||||
select_layer=-4,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
loss_version="v1",
|
||||
min_dynamic_tiles=1,
|
||||
max_dynamic_tiles=6,
|
||||
mlp_checkpoint=False,
|
||||
initializer_range=0.02,
|
||||
_attn_implementation="flash_attention_2",
|
||||
_attn_implementation_autoset=False,
|
||||
llm_config=None,
|
||||
image_token_index=None,
|
||||
use_pixel_shuffle=True,
|
||||
mlp_connector_layers=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {"model_type": "siglip_vision_model"}
|
||||
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {"architectures": ["Qwen2ForCausalLM"]}
|
||||
logger.info(
|
||||
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
|
||||
)
|
||||
|
||||
if vision_config["model_type"] == "siglip_vision_model":
|
||||
self.vision_config = SiglipVisionConfig(**vision_config)
|
||||
else:
|
||||
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
|
||||
|
||||
if text_config["architectures"][0] == "LlamaForCausalLM":
|
||||
self.text_config = LlamaConfig(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
|
||||
self.text_config = Qwen2Config(**text_config)
|
||||
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
|
||||
self.text_config = Qwen3Config(**text_config)
|
||||
else:
|
||||
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.mlp_checkpoint = mlp_checkpoint
|
||||
self.pad2square = pad2square
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.loss_version = loss_version
|
||||
self.initializer_range = initializer_range
|
||||
self.min_dynamic_tiles = min_dynamic_tiles
|
||||
self.max_dynamic_tiles = max_dynamic_tiles
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self._attn_implementation = _attn_implementation
|
||||
self._attn_implementation_autoset = _attn_implementation_autoset
|
||||
self.image_token_index = image_token_index
|
||||
self.use_pixel_shuffle = use_pixel_shuffle
|
||||
self.mlp_connector_layers = mlp_connector_layers
|
||||
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
|
||||
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["text_config"] = self.text_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
output["use_backbone_lora"] = self.use_backbone_lora
|
||||
output["use_llm_lora"] = self.use_llm_lora
|
||||
output["pad2square"] = self.pad2square
|
||||
output["select_layer"] = self.select_layer
|
||||
output["force_image_size"] = self.force_image_size
|
||||
output["downsample_ratio"] = self.downsample_ratio
|
||||
output["template"] = self.template
|
||||
output["dynamic_image_size"] = self.dynamic_image_size
|
||||
output["use_thumbnail"] = self.use_thumbnail
|
||||
output["min_dynamic_tiles"] = self.min_dynamic_tiles
|
||||
output["max_dynamic_tiles"] = self.max_dynamic_tiles
|
||||
output["tie_word_embeddings"] = self.tie_word_embeddings
|
||||
output["_attn_implementation"] = self._attn_implementation
|
||||
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
|
||||
output["use_pixel_shuffle"] = self.use_pixel_shuffle
|
||||
output["mlp_connector_layers"] = self.mlp_connector_layers
|
||||
return output
|
||||
@@ -1,503 +0,0 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
|
||||
from transformers.image_processing_utils import (
|
||||
BatchFeature,
|
||||
get_patch_output_size,
|
||||
)
|
||||
from transformers.image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
ImagesKwargs,
|
||||
group_images_by_shape,
|
||||
reorder_images,
|
||||
)
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
|
||||
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
make_flat_list_of_images,
|
||||
validate_kwargs,
|
||||
)
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_torchvision_v2_available,
|
||||
)
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from transformers.image_utils import pil_torch_interpolation_mapping
|
||||
else:
|
||||
from torchvision.transforms import functional as F # noqa: N812
|
||||
|
||||
|
||||
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
|
||||
"""Crop the given numpy array.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
|
||||
left (int): The left coordinate of the crop box.
|
||||
top (int): The top coordinate of the crop box.
|
||||
right (int): The right coordinate of the crop box.
|
||||
bottom (int): The bottom coordinate of the crop box.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Cropped image.
|
||||
"""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
|
||||
|
||||
if img.ndim not in [2, 3]:
|
||||
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
|
||||
|
||||
img_height = img.shape[1]
|
||||
img_width = img.shape[2]
|
||||
if top < 0 or left < 0 or bottom > img_height or right > img_width:
|
||||
raise ValueError("Crop coordinates out of bounds")
|
||||
|
||||
if top >= bottom or left >= right:
|
||||
raise ValueError("Invalid crop coordinates")
|
||||
|
||||
return img[:, top:bottom, left:right]
|
||||
|
||||
|
||||
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
|
||||
max_dynamic_tiles: int | None
|
||||
min_dynamic_tiles: int | None
|
||||
use_thumbnail: bool | None
|
||||
pad_during_tiling: bool | None
|
||||
do_pad: bool | None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
image_grid_pinpoints (`List[List[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method. Not used for processing videos.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 448, "width": 448}
|
||||
default_to_square = False
|
||||
crop_size = None
|
||||
do_resize = True
|
||||
do_center_crop = None
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_pad = True
|
||||
max_dynamic_tiles = 12
|
||||
min_dynamic_tiles = 1
|
||||
use_thumbnail = True
|
||||
pad_during_tiling = False
|
||||
valid_kwargs = Eagle25VLFastImageProcessorKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@add_start_docstrings(
|
||||
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
|
||||
"""
|
||||
max_dynamic_tiles (`int`, *optional*):
|
||||
The maximum number of dynamic tiles to use for processing high resolution images.
|
||||
min_dynamic_tiles (`int`, *optional*):
|
||||
The minimum number of dynamic tiles to use for processing high resolution images.
|
||||
use_thumbnail (`bool`, *optional*):
|
||||
Whether to use a thumbnail for processing high resolution images.
|
||||
pad_during_tiling (`bool`, *optional*):
|
||||
Whether to pad the image during tiling.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
""",
|
||||
)
|
||||
|
||||
# NOTE(YL): we will overload the preprocess method to add the image_flags
|
||||
# def preprocess(
|
||||
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
|
||||
# ) -> BatchFeature:
|
||||
# return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
self,
|
||||
images: ImageInput,
|
||||
expected_ndims: int = 3,
|
||||
) -> ImageInput:
|
||||
"""
|
||||
Prepare the images structure for processing.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
The input images to process.
|
||||
expected_ndims (`int`, *optional*, defaults to 3):
|
||||
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
|
||||
|
||||
Returns:
|
||||
`ImageInput`: The images with a valid nesting.
|
||||
"""
|
||||
return make_flat_list_of_images(images)
|
||||
|
||||
def _resize_for_patching(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_resolution: tuple,
|
||||
interpolation: F.InterpolationMode,
|
||||
input_data_format: ChannelDimension,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Resizes an image to a target resolution while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image ("torch.Tensor"):
|
||||
The input image.
|
||||
target_resolution (tuple):
|
||||
The target resolution (height, width) of the image.
|
||||
interpolation (`InterpolationMode`):
|
||||
Resampling filter to use if resizing the image.
|
||||
input_data_format (`ChannelDimension` or `str`):
|
||||
The channel dimension format of the input image.
|
||||
|
||||
Returns:
|
||||
"torch.Tensor": The resized and padded image.
|
||||
"""
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
# Resize the image
|
||||
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
|
||||
|
||||
return resized_image
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
|
||||
"""
|
||||
previous version mainly focus on ratio.
|
||||
We also consider area ratio here.
|
||||
"""
|
||||
best_factor = float("-inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
|
||||
"""
|
||||
new area > 60% of original image area is enough.
|
||||
"""
|
||||
factor_based_on_area_n_ratio = min(
|
||||
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
|
||||
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
|
||||
|
||||
if factor_based_on_area_n_ratio > best_factor:
|
||||
best_factor = factor_based_on_area_n_ratio
|
||||
best_ratio = ratio
|
||||
|
||||
return best_ratio
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: torch.Tensor, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
|
||||
paste_x = (target_width - new_width) // 2
|
||||
paste_y = (target_height - new_height) // 2
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
|
||||
|
||||
return padded_image
|
||||
|
||||
def _get_image_patches(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
size: tuple,
|
||||
tile_size: int,
|
||||
use_thumbnail: bool,
|
||||
interpolation: F.InterpolationMode,
|
||||
pad_during_tiling: bool,
|
||||
) -> list[torch.Tensor]:
|
||||
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = tile_size * target_aspect_ratio[0]
|
||||
target_height = tile_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if pad_during_tiling:
|
||||
resized_image = self._resize_for_patching(
|
||||
image,
|
||||
(target_height, target_width),
|
||||
interpolation=interpolation,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
padded_image = self._pad_for_patching(
|
||||
resized_image,
|
||||
(target_height, target_width),
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
)
|
||||
image_used_to_split = padded_image
|
||||
else:
|
||||
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
|
||||
|
||||
processed_tiles = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // tile_size)) * tile_size,
|
||||
(i // (target_width // tile_size)) * tile_size,
|
||||
((i % (target_width // tile_size)) + 1) * tile_size,
|
||||
((i // (target_width // tile_size)) + 1) * tile_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
|
||||
processed_tiles.append(split_img)
|
||||
assert len(processed_tiles) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_tiles) != 1:
|
||||
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
|
||||
processed_tiles.append(thumbnail_img)
|
||||
|
||||
return processed_tiles
|
||||
|
||||
def _pad_for_batching(
|
||||
self,
|
||||
pixel_values: list[torch.Tensor],
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
|
||||
|
||||
Args:
|
||||
pixel_values (`List[torch.Tensor]`):
|
||||
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
|
||||
|
||||
Returns:
|
||||
List[`torch.Tensor`]: The padded images.
|
||||
"""
|
||||
max_patch = max(len(x) for x in pixel_values)
|
||||
pixel_values = [
|
||||
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
|
||||
for image in pixel_values
|
||||
]
|
||||
|
||||
return pixel_values
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
max_dynamic_tiles: int,
|
||||
min_dynamic_tiles: int,
|
||||
use_thumbnail: bool,
|
||||
pad_during_tiling: bool,
|
||||
interpolation: F.InterpolationMode | None,
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: float | list[float] | None,
|
||||
image_std: float | list[float] | None,
|
||||
do_pad: bool,
|
||||
return_tensors: str | TensorType | None,
|
||||
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
else:
|
||||
size_tuple = (size.shortest_edge, size.shortest_edge)
|
||||
|
||||
# Determine the patch size
|
||||
if crop_size and crop_size.height:
|
||||
tile_size = crop_size.height
|
||||
elif size and size.height:
|
||||
tile_size = size.height
|
||||
else:
|
||||
tile_size = size.shortest_edge
|
||||
|
||||
for image in images:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
min_num=min_dynamic_tiles,
|
||||
max_num=max_dynamic_tiles,
|
||||
size=size_tuple,
|
||||
tile_size=tile_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
interpolation=interpolation,
|
||||
pad_during_tiling=pad_during_tiling,
|
||||
)
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
|
||||
image_patches,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
if do_center_crop:
|
||||
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches,
|
||||
do_rescale,
|
||||
rescale_factor,
|
||||
do_normalize,
|
||||
image_mean,
|
||||
image_std,
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(
|
||||
processed_image_patches_grouped, grouped_image_patches_index
|
||||
)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.append(processed_image_patches)
|
||||
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
|
||||
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
validate_kwargs(
|
||||
captured_kwargs=kwargs.keys(),
|
||||
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
|
||||
)
|
||||
# Set default kwargs from self. This ensures that if a kwarg is not provided
|
||||
# by the user, it gets its default value from the instance, or is set to None.
|
||||
for kwarg_name in self.valid_kwargs.__annotations__:
|
||||
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
|
||||
|
||||
# Extract parameters that are only used for preparing the input images
|
||||
do_convert_rgb = kwargs.pop("do_convert_rgb")
|
||||
input_data_format = kwargs.pop("input_data_format")
|
||||
device = kwargs.pop("device")
|
||||
# Prepare input images
|
||||
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
|
||||
if images is not None:
|
||||
images = self._prepare_image_like_inputs(
|
||||
images=images,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if videos is not None:
|
||||
videos = self._prepare_image_like_inputs(
|
||||
images=videos,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Update kwargs that need further processing before being validated
|
||||
kwargs = self._further_process_kwargs(**kwargs)
|
||||
|
||||
# Validate kwargs
|
||||
self._validate_preprocess_kwargs(**kwargs)
|
||||
|
||||
# torch resize uses interpolation instead of resample
|
||||
# Added for transformers >=4.53.0 compatibility
|
||||
resample = kwargs.pop("resample", self.resample)
|
||||
kwargs["interpolation"] = (
|
||||
pil_torch_interpolation_mapping[resample]
|
||||
if isinstance(resample, PILImageResampling | int)
|
||||
else resample
|
||||
)
|
||||
|
||||
# Filter kwargs to only include those accepted by _preprocess
|
||||
valid_preprocess_kwargs = {
|
||||
"do_resize",
|
||||
"size",
|
||||
"max_dynamic_tiles",
|
||||
"min_dynamic_tiles",
|
||||
"use_thumbnail",
|
||||
"pad_during_tiling",
|
||||
"interpolation",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_pad",
|
||||
"return_tensors",
|
||||
"pad_size",
|
||||
"disable_grouping",
|
||||
}
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
|
||||
if images is not None:
|
||||
return self._preprocess(images, **filtered_kwargs)
|
||||
elif videos is not None:
|
||||
return self._preprocess(videos, **filtered_kwargs)
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLImageProcessorFast"]
|
||||
@@ -1,396 +0,0 @@
|
||||
# --------------------------------------------------------
|
||||
# NVIDIA
|
||||
# Copyright (c) 2025 NVIDIA
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as cp
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import GenerationConfig
|
||||
from transformers.generation import GenerationMixin
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
|
||||
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
|
||||
from transformers.utils import add_start_docstrings, logging
|
||||
|
||||
from .configuration_eagle2_5_vl import Eagle25VLConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
|
||||
EAGLE2_5_VL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Eagle25VLConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
|
||||
EAGLE2_5_VL_START_DOCSTRING,
|
||||
)
|
||||
class Eagle25VLPreTrainedModel(PreTrainedModel):
|
||||
config_class = Eagle25VLConfig
|
||||
base_model_prefix = "model"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Qwen2DecoderLayer",
|
||||
"LlamaDecoderLayer",
|
||||
"Siglip2EncoderLayer",
|
||||
"SiglipEncoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear | nn.Conv2d):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
|
||||
config_class = Eagle25VLConfig
|
||||
|
||||
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
|
||||
super().__init__(config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
if config.use_pixel_shuffle:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
|
||||
else:
|
||||
self.num_image_token = int((image_size // patch_size) ** 2)
|
||||
|
||||
self.select_layer = config.select_layer
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.loss_version = config.loss_version
|
||||
self.mlp_checkpoint = config.mlp_checkpoint
|
||||
self.use_pixel_shuffle = config.use_pixel_shuffle
|
||||
self.mlp_connector_layers = config.mlp_connector_layers
|
||||
logger.info(f"num_image_token: {self.num_image_token}")
|
||||
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
|
||||
if vision_model is not None:
|
||||
self.vision_model = vision_model
|
||||
else:
|
||||
if config.vision_config.model_type == "siglip_vision_model":
|
||||
config.vision_config._attn_implementation = "flash_attention_2"
|
||||
self.vision_model = SiglipVisionModel(config.vision_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
|
||||
|
||||
if language_model is not None:
|
||||
self.language_model = language_model
|
||||
else:
|
||||
if config.text_config.architectures[0] == "LlamaForCausalLM":
|
||||
self.language_model = LlamaForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
|
||||
raise NotImplementedError("Phi3 is not implemented.")
|
||||
# self.language_model = Phi3ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
|
||||
assert config.text_config._attn_implementation == "flash_attention_2", (
|
||||
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
|
||||
)
|
||||
self.language_model = Qwen2ForCausalLM(config.text_config)
|
||||
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
|
||||
self.language_model = Qwen3ForCausalLM(config.text_config)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
if config.mlp_connector_layers == 2:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
||||
)
|
||||
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.Linear(vit_hidden_size, llm_hidden_size),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
|
||||
|
||||
self.image_token_index = config.image_token_index
|
||||
self.neftune_alpha = None
|
||||
|
||||
if config.use_backbone_lora:
|
||||
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
|
||||
|
||||
self.use_llm_lora = config.use_llm_lora
|
||||
if config.use_llm_lora:
|
||||
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
|
||||
|
||||
self.check_forward_kwargs()
|
||||
|
||||
def check_forward_kwargs(self):
|
||||
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
|
||||
# has special handling for functions with **kwargs parameters that would affect
|
||||
# how our model is processed during training and inference.
|
||||
forward_params = inspect.signature(self.forward).parameters
|
||||
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
|
||||
|
||||
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.out_proj",
|
||||
"mlp.fc1",
|
||||
"mlp.fc2",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
)
|
||||
self.vision_model = get_peft_model(self.vision_model, lora_config)
|
||||
self.vision_model.print_trainable_parameters()
|
||||
|
||||
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=[
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.o_proj",
|
||||
"mlp.gate_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.up_proj",
|
||||
],
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.language_model = get_peft_model(self.language_model, lora_config)
|
||||
self.language_model.enable_input_require_grads()
|
||||
self.language_model.print_trainable_parameters()
|
||||
self.use_llm_lora = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
image_flags: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_tiles_list: list[torch.Tensor] | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
if image_flags is not None:
|
||||
image_flags = image_flags.view(-1)
|
||||
vit_embeds = vit_embeds[image_flags == 1]
|
||||
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.image_token_index
|
||||
try:
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
|
||||
except Exception as e:
|
||||
vit_embeds = vit_embeds.reshape(-1, c)
|
||||
print(
|
||||
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
|
||||
f"vit_embeds.shape={vit_embeds.shape}"
|
||||
)
|
||||
n_token = selected.sum()
|
||||
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
logits = outputs.logits
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
|
||||
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
if self.select_layer == -1:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
|
||||
)
|
||||
if hasattr(vit_embeds, "last_hidden_state"):
|
||||
vit_embeds = vit_embeds.last_hidden_state
|
||||
|
||||
else:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
|
||||
).hidden_states[self.select_layer]
|
||||
|
||||
if self.use_pixel_shuffle:
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
|
||||
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
|
||||
|
||||
if self.mlp_checkpoint and vit_embeds.requires_grad:
|
||||
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
|
||||
else:
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
|
||||
return vit_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
input_ids: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
visual_features: torch.FloatTensor | None = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
image_sizes: list[tuple[int, int]] | None = None,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
if pixel_values is not None:
|
||||
if visual_features is not None:
|
||||
vit_embeds = visual_features
|
||||
else:
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
b, n, c = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(b * n, c)
|
||||
|
||||
input_ids = input_ids.reshape(b * n)
|
||||
selected = input_ids == self.config.image_token_index
|
||||
assert selected.sum() != 0
|
||||
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
|
||||
|
||||
input_embeds = input_embeds.reshape(b, n, c)
|
||||
else:
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if "use_cache" not in generate_kwargs:
|
||||
generate_kwargs["use_cache"] = True
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=generation_config,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
@@ -1,541 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for Eagle25VL.
|
||||
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils import logging
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
FRAME_FACTOR = 2
|
||||
FPS = 2.0
|
||||
FPS_MIN_FRAMES = 4
|
||||
FPS_MAX_FRAMES = 256
|
||||
|
||||
|
||||
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||
if pil_image.mode == "RGBA":
|
||||
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||
return white_background
|
||||
else:
|
||||
return pil_image.convert("RGB")
|
||||
|
||||
|
||||
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
|
||||
image = ele["image"] if "image" in ele else ele["image_url"]
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
response = requests.get(image, stream=True, timeout=10)
|
||||
image_obj = Image.open(BytesIO(response.content))
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(
|
||||
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
|
||||
)
|
||||
image = to_rgb(image_obj)
|
||||
if "scale_factor" in ele:
|
||||
scale_factor = ele["scale_factor"]
|
||||
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
|
||||
return image
|
||||
|
||||
|
||||
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
# see processing_utils.ProcessingKwargs documentation for usage.
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {},
|
||||
"videos_kwargs": {"max_dynamic_tiles": 1},
|
||||
}
|
||||
|
||||
|
||||
class Eagle25VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
|
||||
|
||||
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
|
||||
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`LlamaTokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
num_image_tokens (`int`, *optional*):
|
||||
Number of image tokens for one imagethat will be returned by vision tower.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Should be same as in model's config
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
image_token (`str`, *optional*, defaults to `"<image>"`):
|
||||
Special token used to denote image location.
|
||||
video_token (`str`, *optional*, defaults to `"<video>"`):
|
||||
Special token used to denote video location.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = [
|
||||
"chat_template",
|
||||
"num_image_tokens",
|
||||
"vision_feature_select_strategy",
|
||||
"image_token",
|
||||
"video_token",
|
||||
"images_kwargs",
|
||||
"videos_kwargs",
|
||||
"text_kwargs",
|
||||
]
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
vision_feature_select_strategy=None,
|
||||
chat_template=None,
|
||||
image_token="<IMG_CONTEXT>", # nosec: B107
|
||||
video_token="<IMG_CONTEXT>", # nosec: B107
|
||||
tokens_per_tile=256,
|
||||
image_placeholder="image",
|
||||
video_placeholder="video",
|
||||
image_start_token="<img>",
|
||||
image_end_token="</img>",
|
||||
**kwargs,
|
||||
):
|
||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
|
||||
self.image_token_id = (
|
||||
tokenizer.image_token_id
|
||||
if getattr(tokenizer, "image_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
)
|
||||
self.video_token_id = (
|
||||
tokenizer.video_token_id
|
||||
if getattr(tokenizer, "video_token_id", None)
|
||||
else tokenizer.convert_tokens_to_ids(self.video_token)
|
||||
)
|
||||
self.image_placeholder = image_placeholder
|
||||
self.video_placeholder = video_placeholder
|
||||
self.tokens_per_tile = tokens_per_tile
|
||||
self.image_start_token = image_start_token
|
||||
self.image_end_token = image_end_token
|
||||
if "auto_map" in kwargs:
|
||||
self.auto_map = kwargs["auto_map"]
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def replace_media_placeholder(
|
||||
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
|
||||
):
|
||||
num_of_images_in_this_sample = 0
|
||||
num_of_videos_in_this_sample = 0
|
||||
# Regular expression pattern to match formats like <image-1> or <video-2>
|
||||
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
|
||||
unified_frame_list = []
|
||||
|
||||
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
# )
|
||||
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
|
||||
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
# )
|
||||
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
|
||||
# "use_thumbnail", self.image_processor.use_thumbnail
|
||||
# )
|
||||
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
|
||||
)
|
||||
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
|
||||
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
|
||||
)
|
||||
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
|
||||
"use_thumbnail", self.image_processor.use_thumbnail
|
||||
)
|
||||
|
||||
tile_size = self.image_processor.size.get("height", 448)
|
||||
|
||||
# Function to replace tags in a single text
|
||||
def replace_in_text(text):
|
||||
# repl callback function for each match replacement operation
|
||||
def repl(match):
|
||||
nonlocal unified_frame_list
|
||||
nonlocal num_of_images_in_this_sample
|
||||
nonlocal num_of_videos_in_this_sample
|
||||
media_type = match.group(1) # 'image' or 'video'
|
||||
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
|
||||
# Select the corresponding path based on media type
|
||||
idx_mapper = {
|
||||
0: "first",
|
||||
1: "second",
|
||||
2: "third",
|
||||
3: "fourth",
|
||||
4: "fifth",
|
||||
5: "sixth",
|
||||
6: "seventh",
|
||||
7: "eighth",
|
||||
8: "ninth",
|
||||
9: "tenth",
|
||||
}
|
||||
if media_type == "image":
|
||||
image_inputs = self.image_processor(
|
||||
images=[image_list[idx_in_list]],
|
||||
videos=None,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
if isinstance(image_inputs["pixel_values"], list):
|
||||
_pv = image_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
image_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = image_inputs["pixel_values"].shape[0]
|
||||
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
unified_frame_list.append(image_inputs)
|
||||
num_of_images_in_this_sample += 1
|
||||
|
||||
elif media_type == "video":
|
||||
video_inputs = self.image_processor(
|
||||
images=None,
|
||||
videos=[video_list[idx_in_list]],
|
||||
**output_kwargs["videos_kwargs"],
|
||||
)
|
||||
if isinstance(video_inputs["pixel_values"], list):
|
||||
_pv = video_inputs["pixel_values"]
|
||||
if _pv and isinstance(_pv[0], list):
|
||||
_pv = [t for sub in _pv for t in sub]
|
||||
video_inputs["pixel_values"] = torch.stack(
|
||||
[t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in _pv]
|
||||
)
|
||||
num_all_tiles = video_inputs["pixel_values"].shape[0]
|
||||
image_sizes = video_inputs["image_sizes"]
|
||||
if timestamps_list is not None and -1 not in timestamps_list:
|
||||
frame_timestamps = timestamps_list[idx_in_list]
|
||||
else:
|
||||
frame_timestamps = None
|
||||
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
|
||||
|
||||
num_of_tiles_each_frame = [
|
||||
self.get_number_tiles_based_on_image_size(
|
||||
image_size,
|
||||
video_min_dynamic_tiles,
|
||||
video_max_dynamic_tiles,
|
||||
video_use_thumbnail,
|
||||
tile_size,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
|
||||
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
|
||||
)
|
||||
|
||||
if frame_timestamps is not None:
|
||||
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
|
||||
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
|
||||
)
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
else:
|
||||
special_placeholder = [
|
||||
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
|
||||
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
|
||||
]
|
||||
|
||||
if sampled_fps is not None:
|
||||
special_placeholder = (
|
||||
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
|
||||
+ "".join(special_placeholder)
|
||||
)
|
||||
else:
|
||||
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
|
||||
special_placeholder
|
||||
)
|
||||
unified_frame_list.append(video_inputs)
|
||||
num_of_videos_in_this_sample += 1
|
||||
else:
|
||||
raise ValueError(f"Unknown media type: {media_type}")
|
||||
return special_placeholder
|
||||
|
||||
return pattern.sub(repl, text)
|
||||
|
||||
text = replace_in_text(text)
|
||||
if len(unified_frame_list) > 0:
|
||||
|
||||
def _to_tensor(v):
|
||||
if isinstance(v, torch.Tensor):
|
||||
return v
|
||||
if isinstance(v, list):
|
||||
if v and isinstance(v[0], list):
|
||||
v = [t for sub in v for t in sub]
|
||||
return torch.stack([t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in v])
|
||||
return torch.as_tensor(v)
|
||||
|
||||
pixel_values = torch.cat([_to_tensor(frame["pixel_values"]) for frame in unified_frame_list])
|
||||
image_sizes = torch.cat([_to_tensor(frame["image_sizes"]) for frame in unified_frame_list])
|
||||
else:
|
||||
pixel_values = None
|
||||
image_sizes = None
|
||||
return (
|
||||
text,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
audio=None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Eagle25VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Eagle25VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(text, str):
|
||||
text_list = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
elif isinstance(text, list) and isinstance(text[0], str):
|
||||
text_list = text
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if videos is None:
|
||||
videos = []
|
||||
|
||||
pixel_values_list = []
|
||||
image_sizes_list = []
|
||||
new_sample_list = []
|
||||
image_start_idx = 0
|
||||
video_start_idx = 0
|
||||
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
|
||||
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
|
||||
for sample in text_list:
|
||||
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
|
||||
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
|
||||
(
|
||||
sample,
|
||||
pixel_values,
|
||||
image_sizes,
|
||||
num_of_images_in_this_sample,
|
||||
num_of_videos_in_this_sample,
|
||||
) = self.replace_media_placeholder(
|
||||
sample,
|
||||
images[image_start_idx:],
|
||||
videos[video_start_idx:],
|
||||
timestamps_list,
|
||||
fps_list,
|
||||
**output_kwargs,
|
||||
)
|
||||
new_sample_list.append(sample)
|
||||
if pixel_values is not None:
|
||||
pixel_values_list.append(pixel_values)
|
||||
image_sizes_list.append(image_sizes)
|
||||
image_start_idx += num_of_images_in_this_sample
|
||||
video_start_idx += num_of_videos_in_this_sample
|
||||
|
||||
if len(pixel_values_list) > 0:
|
||||
image_inputs = {
|
||||
"pixel_values": torch.cat(pixel_values_list),
|
||||
"image_sizes": torch.cat(image_sizes_list),
|
||||
}
|
||||
else:
|
||||
image_inputs = {}
|
||||
video_inputs = {}
|
||||
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
|
||||
|
||||
def get_number_tiles_based_on_image_size(
|
||||
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of tiles based on the image size.
|
||||
"""
|
||||
orig_height, orig_width = image_size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
|
||||
)
|
||||
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
if use_thumbnail and tiles_num > 1:
|
||||
tiles_num += 1
|
||||
return tiles_num
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# override to save video-config in a separate config file
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
outputs = super().save_pretrained(save_directory, **kwargs)
|
||||
return outputs
|
||||
|
||||
# override to load video-config from a separate config file
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
|
||||
if isinstance(processor, tuple):
|
||||
processor = processor[0]
|
||||
return processor
|
||||
|
||||
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
def process_vision_info(
|
||||
self,
|
||||
conversations: list[dict] | list[list[dict]],
|
||||
return_video_kwargs: bool = False,
|
||||
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
|
||||
vision_infos = self.extract_vision_info(conversations)
|
||||
## Read images or videos
|
||||
image_inputs = []
|
||||
video_inputs = []
|
||||
video_sample_fps_list = []
|
||||
video_timestamps_list = []
|
||||
for vision_info in vision_infos:
|
||||
if "image" in vision_info or "image_url" in vision_info:
|
||||
image_inputs.append(fetch_image(vision_info))
|
||||
else:
|
||||
raise ValueError("image, image_url or video should in content.")
|
||||
if len(image_inputs) == 0:
|
||||
image_inputs = None
|
||||
if len(video_inputs) == 0:
|
||||
video_inputs = None
|
||||
if return_video_kwargs:
|
||||
return (
|
||||
image_inputs,
|
||||
video_inputs,
|
||||
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
|
||||
)
|
||||
return image_inputs, video_inputs
|
||||
|
||||
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
||||
vision_infos = []
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
for conversation in conversations:
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if (
|
||||
"image" in ele
|
||||
or "image_url" in ele
|
||||
or "video" in ele
|
||||
or ele["type"] in ("image", "image_url", "video")
|
||||
):
|
||||
vision_infos.append(ele)
|
||||
return vision_infos
|
||||
|
||||
|
||||
__all__ = ["Eagle25VLProcessor"]
|
||||
@@ -1,379 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
|
||||
|
||||
from .action_head.flow_matching_action_head import (
|
||||
FlowmatchingActionHead,
|
||||
FlowmatchingActionHeadConfig,
|
||||
)
|
||||
from .utils import ensure_eagle_cache_ready
|
||||
|
||||
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
|
||||
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
|
||||
|
||||
|
||||
class EagleBackbone(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
|
||||
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
|
||||
project_to_dim: int = 1536,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tune_llm: whether to tune the LLM model (default: True)
|
||||
tune_visual: whether to tune the visual model (default: False)
|
||||
"""
|
||||
super().__init__()
|
||||
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
|
||||
|
||||
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
|
||||
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
|
||||
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
|
||||
try:
|
||||
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
|
||||
|
||||
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
|
||||
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
if project_to_dim is not None:
|
||||
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
|
||||
else:
|
||||
self.eagle_linear = torch.nn.Identity()
|
||||
|
||||
# needed since we don't use these layers. Also saves compute
|
||||
while len(self.eagle_model.language_model.model.layers) > select_layer:
|
||||
self.eagle_model.language_model.model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual)
|
||||
|
||||
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for p in self.parameters():
|
||||
p.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.eagle_model.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.eagle_model.vision_model.requires_grad_(False)
|
||||
self.eagle_model.mlp1.requires_grad_(False)
|
||||
print(f"Tune backbone llm: {self.tune_llm}")
|
||||
print(f"Tune backbone visual: {self.tune_visual}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not tune_llm and not tune_visual:
|
||||
for name, p in self.named_parameters():
|
||||
if p.requires_grad:
|
||||
print(f"Backbone trainable parameter: {name}")
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No backbone trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
if self.eagle_model.language_model and not self.tune_llm:
|
||||
self.eagle_model.language_model.eval()
|
||||
if self.eagle_model.vision_model and not self.tune_visual:
|
||||
self.eagle_model.vision_model.eval()
|
||||
|
||||
def prepare_input(self, batch: dict) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
eagle_prefix = "eagle_"
|
||||
eagle_input = {
|
||||
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
|
||||
}
|
||||
del eagle_input["image_sizes"]
|
||||
|
||||
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
|
||||
eagle_features = eagle_output.hidden_states[self.select_layer]
|
||||
|
||||
eagle_features = self.eagle_linear(eagle_features)
|
||||
return eagle_features, eagle_input["attention_mask"]
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
|
||||
|
||||
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
|
||||
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
|
||||
if self.training and self.tune_visual:
|
||||
dummy_term = torch.tensor(
|
||||
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
|
||||
)
|
||||
for param in self.eagle_model.vision_model.parameters():
|
||||
if param.requires_grad:
|
||||
dummy_term = dummy_term + 0.0 * param.sum()
|
||||
eagle_embeds = eagle_embeds + dummy_term
|
||||
|
||||
return BatchFeature(
|
||||
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
|
||||
) # [B, T2, hidden_size]
|
||||
|
||||
|
||||
BACKBONE_FEATURE_KEY = "backbone_features"
|
||||
ACTION_KEY = "action_pred"
|
||||
LOSS_KEY = "loss"
|
||||
ERROR_MSG = "Error: unexpected input/output"
|
||||
N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
class GR00TN15(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
config_class = GR00TN15Config
|
||||
"""
|
||||
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
|
||||
here n is variable and can be e.g. time, 1 or user specified
|
||||
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
|
||||
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN15Config,
|
||||
local_model_path: str,
|
||||
):
|
||||
assert isinstance(config.backbone_cfg, dict)
|
||||
assert isinstance(config.action_head_cfg, dict)
|
||||
|
||||
super().__init__(config)
|
||||
self.local_model_path = local_model_path
|
||||
|
||||
self.backbone = EagleBackbone(**config.backbone_cfg)
|
||||
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
|
||||
self.action_head = FlowmatchingActionHead(action_head_cfg)
|
||||
|
||||
self.action_horizon = config.action_horizon
|
||||
self.action_dim = config.action_dim
|
||||
self.compute_dtype = config.compute_dtype
|
||||
self.post_init()
|
||||
|
||||
def validate_inputs(self, inputs):
|
||||
# NOTE -- this should be handled internally by the model
|
||||
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
|
||||
|
||||
detected_error = False
|
||||
error_msg = ERROR_MSG
|
||||
if ACTION in inputs:
|
||||
action = inputs[ACTION]
|
||||
# In inference, action may be omitted or None; validate only when it's a tensor.
|
||||
if action is None:
|
||||
pass # allow None during inference
|
||||
elif isinstance(action, torch.Tensor):
|
||||
shape_ok = (
|
||||
len(action.shape) == 3
|
||||
and action.shape[1] == self.action_horizon
|
||||
and action.shape[2] == self.action_dim
|
||||
)
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{action.shape=}"
|
||||
detected_error = True
|
||||
else:
|
||||
# Unexpected non-tensor type provided for action
|
||||
error_msg += f"\nInvalid type for action: {type(action)}"
|
||||
detected_error = True
|
||||
|
||||
if "video" in inputs:
|
||||
video = inputs["video"]
|
||||
type_ok = isinstance(video, np.ndarray)
|
||||
dtype_ok = video.dtype == np.uint8
|
||||
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
|
||||
if not type_ok:
|
||||
error_msg += f"\n{type(video)=}"
|
||||
detected_error = True
|
||||
if not dtype_ok:
|
||||
error_msg += f"\n{video.dtype=}"
|
||||
detected_error = True
|
||||
if not shape_ok:
|
||||
error_msg += f"\n{video.shape=}"
|
||||
detected_error = True
|
||||
|
||||
if detected_error:
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
|
||||
fail_backbone = (
|
||||
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
|
||||
)
|
||||
|
||||
if fail_backbone:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
|
||||
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
|
||||
(
|
||||
LOSS_KEY in action_head_outputs and is_training
|
||||
) # there might not be an action prediction during training
|
||||
or (
|
||||
ACTION_KEY in action_head_outputs
|
||||
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
|
||||
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
|
||||
)
|
||||
)
|
||||
|
||||
if fail_action_head:
|
||||
error_msg = ERROR_MSG
|
||||
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
|
||||
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
|
||||
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
|
||||
error_msg += f"\n{self.action_horizon=}"
|
||||
error_msg += f"\n{self.action_dim=}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
|
||||
return action_head_outputs
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
inputs: dict,
|
||||
) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
|
||||
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
|
||||
return action_head_outputs
|
||||
|
||||
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
|
||||
self.validate_inputs(inputs)
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_maybe_dtype(x):
|
||||
# Cast floating tensors to a memory-efficient compute dtype when requested.
|
||||
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
|
||||
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
if getattr(self, "compute_dtype", None) == "bfloat16":
|
||||
return x.to(self.device, dtype=torch.bfloat16)
|
||||
# Fallback: preserve previous behavior if not using bf16 compute
|
||||
return x.to(self.device, dtype=self.action_head.dtype)
|
||||
# Non-floating tensors: move device only
|
||||
return x.to(self.device)
|
||||
|
||||
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
|
||||
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
|
||||
return backbone_inputs, action_inputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
|
||||
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
|
||||
print(f"Tune backbone vision tower: {tune_visual}")
|
||||
print(f"Tune backbone LLM: {tune_llm}")
|
||||
print(f"Tune action head projector: {tune_projector}")
|
||||
print(f"Tune action head DiT: {tune_diffusion_model}")
|
||||
|
||||
# get the current model path being downloaded
|
||||
try:
|
||||
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
|
||||
# saved in ~/.cache/huggingface/hub/
|
||||
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
|
||||
# HFValidationError, RepositoryNotFoundError
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
print(
|
||||
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
|
||||
)
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path, local_model_path=local_model_path, **kwargs
|
||||
)
|
||||
|
||||
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
|
||||
)
|
||||
return pretrained_model
|
||||
@@ -0,0 +1,966 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
except ImportError:
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _copy_default(value: Any) -> Any:
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
"model_name": "nvidia/Cosmos-Reason2-2B",
|
||||
"backbone_model_type": "qwen",
|
||||
"model_revision": None,
|
||||
"tune_top_llm_layers": 0,
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 16,
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"load_bf16": False,
|
||||
"backbone_trainable_params_fp32": True,
|
||||
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||
"shortest_image_edge": None,
|
||||
"crop_fraction": None,
|
||||
"random_rotation_angle": None,
|
||||
"color_jitter_params": None,
|
||||
"use_albumentations_transforms": True,
|
||||
"extra_augmentation_config": None,
|
||||
"formalize_language": True,
|
||||
"apply_sincos_state_encoding": False,
|
||||
"use_percentiles": True,
|
||||
"use_relative_action": False,
|
||||
"max_state_dim": 132,
|
||||
"max_action_dim": 132,
|
||||
"action_horizon": 40,
|
||||
"hidden_size": 1024,
|
||||
"input_embedding_dim": 1536,
|
||||
"state_history_length": 1,
|
||||
"add_pos_embed": True,
|
||||
"attn_dropout": 0.2,
|
||||
"use_vlln": True,
|
||||
"max_seq_len": 1024,
|
||||
"use_alternate_vl_dit": True,
|
||||
"attend_text_every_n_blocks": 2,
|
||||
"diffusion_model_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 32,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 48,
|
||||
"norm_type": "ada_norm",
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
"output_dim": 1024,
|
||||
"interleave_self_attention": True,
|
||||
},
|
||||
"vl_self_attention_cfg": {
|
||||
"positional_embeddings": None,
|
||||
"num_layers": 4,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 64,
|
||||
"dropout": 0.2,
|
||||
"final_dropout": True,
|
||||
},
|
||||
"num_inference_timesteps": 4,
|
||||
"noise_beta_alpha": 1.5,
|
||||
"noise_beta_beta": 1.0,
|
||||
"noise_s": 0.999,
|
||||
"num_timestep_buckets": 1000,
|
||||
"tune_projector": True,
|
||||
"tune_diffusion_model": True,
|
||||
"tune_vlln": True,
|
||||
"state_dropout_prob": 0.2,
|
||||
"exclude_state": False,
|
||||
"use_mean_std": False,
|
||||
"max_num_embodiments": 32,
|
||||
"rtc_ramp_rate": 6.0,
|
||||
}
|
||||
|
||||
|
||||
class GR00TN17Config(PretrainedConfig):
|
||||
"""Configuration for NVIDIA GR00T N1.7.
|
||||
|
||||
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
|
||||
flow-matching action head. This mirrors the public N1.7 checkpoint config
|
||||
while keeping it local to LeRobot and independent from the external
|
||||
Isaac-GR00T ``gr00t`` Python package.
|
||||
"""
|
||||
|
||||
model_type = "Gr00tN1d7"
|
||||
|
||||
_defaults = GR00T_N1_7_DEFAULTS
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in GR00T_N1_7_DEFAULTS.items():
|
||||
setattr(self, key, _copy_default(kwargs.pop(key, value)))
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
|
||||
cfg = self.to_dict()
|
||||
if not exclude_augment:
|
||||
return cfg
|
||||
exclude_keys = {
|
||||
"random_rotation_angle",
|
||||
"color_jitter_params",
|
||||
"use_albumentations_transforms",
|
||||
"formalize_language",
|
||||
"image_crop_size",
|
||||
"image_target_size",
|
||||
"shortest_image_edge",
|
||||
"crop_fraction",
|
||||
}
|
||||
return {k: v for k, v in cfg.items() if k not in exclude_keys}
|
||||
|
||||
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
|
||||
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
"""Linear layer with category-specific weights for multi-embodiment support."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
selected_w = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
"""Two-layer MLP with category-specific weights."""
|
||||
|
||||
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
|
||||
|
||||
The frequency scalar is intentionally created on CPU and then broadcast with
|
||||
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
|
||||
embedding and avoids tiny dtype/device construction differences in parity
|
||||
tests.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
|
||||
|
||||
|
||||
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
"""Action encoder with category-specific projections and sinusoidal time encoding."""
|
||||
|
||||
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
|
||||
super().__init__()
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, horizon, _ = actions.shape
|
||||
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("Expected `timesteps` to have shape (B,).")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
|
||||
action_emb = self.W1(actions, cat_ids)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
|
||||
return self.W3(x, cat_ids)
|
||||
|
||||
|
||||
class Qwen3Backbone(nn.Module):
|
||||
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
|
||||
|
||||
The public checkpoint stores the action head in the GR00T checkpoint but
|
||||
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
|
||||
keeps the nested HF module layout compatible across transformer versions
|
||||
and exposes the hidden states consumed by the action head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "nvidia/Cosmos-Reason2-2B",
|
||||
tune_llm: bool = False,
|
||||
tune_visual: bool = False,
|
||||
select_layer: int = -1,
|
||||
reproject_vision: bool = False,
|
||||
use_flash_attention: bool = False,
|
||||
load_bf16: bool = False,
|
||||
tune_top_llm_layers: int = 0,
|
||||
trainable_params_fp32: bool = False,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_pretrained_weights: bool = True,
|
||||
):
|
||||
if Qwen3VLForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
|
||||
"or use a transformers version that provides Qwen3-VL support."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if use_flash_attention:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
|
||||
extra_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
except ImportError:
|
||||
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
|
||||
extra_kwargs["attn_implementation"] = "sdpa"
|
||||
if load_bf16:
|
||||
extra_kwargs["torch_dtype"] = torch.bfloat16
|
||||
|
||||
if load_pretrained_weights:
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
**extra_kwargs,
|
||||
**transformers_loading_kwargs,
|
||||
).eval()
|
||||
else:
|
||||
self.model = self._from_backbone_config(
|
||||
model_name=model_name,
|
||||
model_kwargs=extra_kwargs,
|
||||
config_kwargs=transformers_loading_kwargs,
|
||||
).eval()
|
||||
|
||||
while len(self.language_model.layers) > select_layer:
|
||||
self.language_model.layers.pop(-1)
|
||||
|
||||
self.select_layer = select_layer
|
||||
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
|
||||
if load_bf16 and trainable_params_fp32:
|
||||
for parameter in self.parameters():
|
||||
if parameter.requires_grad:
|
||||
parameter.data = parameter.data.to(torch.float32)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
|
||||
) -> None:
|
||||
self.tune_llm = tune_llm
|
||||
self.tune_visual = tune_visual
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_llm:
|
||||
self.language_model.requires_grad_(False)
|
||||
if not tune_visual:
|
||||
self.visual.requires_grad_(False)
|
||||
if tune_top_llm_layers > 0:
|
||||
for layer in self.language_model.layers[-tune_top_llm_layers:]:
|
||||
for parameter in layer.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if self.language_model and not self.tune_llm:
|
||||
self.language_model.eval()
|
||||
if self.visual and not self.tune_visual:
|
||||
self.visual.eval()
|
||||
|
||||
@property
|
||||
def language_model(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).language_model
|
||||
|
||||
@property
|
||||
def visual(self) -> nn.Module:
|
||||
return getattr(self.model, "model", self.model).visual
|
||||
|
||||
def _from_backbone_config(
|
||||
self,
|
||||
*,
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
config_kwargs: dict[str, Any],
|
||||
) -> nn.Module:
|
||||
if _is_cosmos_reason2_backbone(model_name):
|
||||
backbone_config = _cosmos_reason2_qwen3_vl_config()
|
||||
else:
|
||||
if AutoConfig is None:
|
||||
raise ImportError(
|
||||
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
|
||||
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
if "mm_token_type_ids" in model_input:
|
||||
return
|
||||
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
|
||||
return
|
||||
|
||||
input_ids = model_input.get("input_ids")
|
||||
if input_ids is None:
|
||||
return
|
||||
|
||||
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[input_ids == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[input_ids == video_token_id] = 2
|
||||
|
||||
model_input["mm_token_type_ids"] = mm_token_type_ids
|
||||
|
||||
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
|
||||
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
|
||||
|
||||
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
|
||||
drops text position ids before calling text-layer flash attention. GR00T
|
||||
N1.7 was aligned against the older Transformers path, where a fourth text
|
||||
position row is forwarded alongside the temporal/height/width rows. Adding
|
||||
the row here preserves the newer multimodal position computation while
|
||||
keeping flash attention on the legacy code path.
|
||||
"""
|
||||
|
||||
if "position_ids" in model_input:
|
||||
return
|
||||
|
||||
qwen3_model = getattr(self.model, "model", self.model)
|
||||
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
|
||||
if compute_3d_position_ids is None:
|
||||
return
|
||||
|
||||
position_ids = compute_3d_position_ids(
|
||||
input_ids=model_input.get("input_ids"),
|
||||
image_grid_thw=model_input.get("image_grid_thw"),
|
||||
video_grid_thw=model_input.get("video_grid_thw"),
|
||||
inputs_embeds=None,
|
||||
attention_mask=model_input.get("attention_mask"),
|
||||
past_key_values=None,
|
||||
mm_token_type_ids=model_input.get("mm_token_type_ids"),
|
||||
)
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
|
||||
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
|
||||
|
||||
model_input["position_ids"] = position_ids
|
||||
|
||||
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
|
||||
|
||||
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
|
||||
Newer releases expose the post-final-norm tensor there instead. Capturing
|
||||
the last decoder layer output directly keeps the N1.7 action head input
|
||||
stable across Transformers versions.
|
||||
"""
|
||||
|
||||
captured: dict[str, torch.Tensor] = {}
|
||||
|
||||
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
|
||||
if isinstance(output, torch.Tensor):
|
||||
captured["features"] = output
|
||||
elif isinstance(output, (tuple, list)) and output:
|
||||
captured["features"] = output[0]
|
||||
elif hasattr(output, "last_hidden_state"):
|
||||
captured["features"] = output.last_hidden_state
|
||||
|
||||
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
|
||||
try:
|
||||
outputs = self.model(**model_input, output_hidden_states=True)
|
||||
finally:
|
||||
hook.remove()
|
||||
|
||||
return captured.get("features", outputs.hidden_states[-1])
|
||||
|
||||
def forward(self, vl_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
|
||||
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
|
||||
model_input = {key: vl_input[key] for key in keys_to_use}
|
||||
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
|
||||
self._ensure_mm_token_type_ids(model_input)
|
||||
self._ensure_legacy_qwen3_position_ids(model_input)
|
||||
features = self._last_decoder_layer_output(model_input)
|
||||
image_mask = model_input["input_ids"] == self.model.config.image_token_id
|
||||
attention_mask = model_input["attention_mask"] == 1
|
||||
return BatchFeature(
|
||||
data={
|
||||
"backbone_features": features,
|
||||
"backbone_attention_mask": attention_mask,
|
||||
"image_mask": image_mask,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class GR00TN17ActionHead(nn.Module):
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, config: GR00TN17Config):
|
||||
require_package("diffusers", extra="groot")
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.input_embedding_dim = config.input_embedding_dim
|
||||
|
||||
if config.use_alternate_vl_dit:
|
||||
self.model = AlternateVLDiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
|
||||
)
|
||||
else:
|
||||
self.model = DiT(
|
||||
**config.diffusion_model_cfg,
|
||||
cross_attention_dim=config.backbone_embedding_dim,
|
||||
)
|
||||
|
||||
self.action_dim = config.max_action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=config.max_state_dim * config.state_history_length,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.input_embedding_dim,
|
||||
)
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=self.action_dim,
|
||||
hidden_size=self.input_embedding_dim,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
|
||||
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
|
||||
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
|
||||
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
|
||||
else:
|
||||
self.vl_self_attention = nn.Identity()
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
self.state_dropout_prob = config.state_dropout_prob
|
||||
self._noise_beta_alpha = config.noise_beta_alpha
|
||||
self._noise_beta_beta = config.noise_beta_beta
|
||||
self._beta_dist = None
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
|
||||
|
||||
def set_trainable_parameters(
|
||||
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
|
||||
) -> None:
|
||||
self.tune_projector = tune_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
self.tune_vlln = tune_vlln
|
||||
for parameter in self.parameters():
|
||||
parameter.requires_grad = True
|
||||
if not tune_projector:
|
||||
self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
if not tune_vlln:
|
||||
self.vlln.requires_grad_(False)
|
||||
self.vl_self_attention.requires_grad_(False)
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self) -> None:
|
||||
if self.training:
|
||||
if not self.tune_projector:
|
||||
self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
if not self.tune_vlln:
|
||||
self.vlln.eval()
|
||||
self.vl_self_attention.eval()
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
if self._beta_dist is None:
|
||||
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
|
||||
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
|
||||
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
|
||||
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (1 - sample) * self.config.noise_s
|
||||
|
||||
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
|
||||
backbone_features = self.vlln(backbone_output["backbone_features"])
|
||||
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
|
||||
return backbone_output
|
||||
|
||||
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
vl_embeds = backbone_output.backbone_features
|
||||
device = vl_embeds.device
|
||||
embodiment_id = action_input.embodiment_id
|
||||
|
||||
if action_input.state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = action_input.state.view(action_input.state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, embodiment_id)
|
||||
|
||||
if self.training and self.state_dropout_prob > 0:
|
||||
do_dropout = (
|
||||
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
|
||||
)
|
||||
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
|
||||
|
||||
actions = action_input.action
|
||||
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None]
|
||||
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output, _ = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
encoder_attention_mask=backbone_output.backbone_attention_mask,
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
)
|
||||
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
|
||||
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
|
||||
return BatchFeature(
|
||||
data={
|
||||
"loss": loss,
|
||||
"action_loss": action_loss,
|
||||
"action_mask": action_mask,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
|
||||
backbone_output = self.process_backbone_output(backbone_output)
|
||||
state = action_input.state
|
||||
if state.shape[1] != self.config.state_history_length:
|
||||
raise ValueError("state history length does not match GR00T N1.7 config.")
|
||||
state = state.view(state.shape[0], 1, -1)
|
||||
state_features = self.state_encoder(state, action_input.embodiment_id)
|
||||
return BatchFeature(
|
||||
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action_with_features(
|
||||
self,
|
||||
backbone_features: torch.Tensor,
|
||||
state_features: torch.Tensor,
|
||||
embodiment_id: torch.Tensor,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
vl_embeds = backbone_features
|
||||
batch_size = vl_embeds.shape[0]
|
||||
device = vl_embeds.device
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.action_dim),
|
||||
dtype=vl_embeds.dtype,
|
||||
device=device,
|
||||
)
|
||||
dt = 1.0 / self.num_inference_timesteps
|
||||
vel_strength = torch.ones_like(actions)
|
||||
|
||||
if "action" in action_input:
|
||||
if options is None:
|
||||
raise ValueError("RTC options are required when action is provided to get_action.")
|
||||
action_horizon_before_padding = options["action_horizon"]
|
||||
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
|
||||
:,
|
||||
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
|
||||
:,
|
||||
]
|
||||
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
|
||||
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
|
||||
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
|
||||
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
|
||||
ramp = ramp / ramp[-1].clamp_min(1e-8)
|
||||
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
|
||||
None, :, None
|
||||
].to(device)
|
||||
|
||||
for t_step in range(self.num_inference_timesteps):
|
||||
t_cont = t_step / float(self.num_inference_timesteps)
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
|
||||
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
|
||||
if self.config.add_pos_embed:
|
||||
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
|
||||
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
|
||||
sa_embs = torch.cat((state_features, action_features), dim=1)
|
||||
|
||||
if self.config.use_alternate_vl_dit:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
image_mask=backbone_output.image_mask,
|
||||
backbone_attention_mask=backbone_output.backbone_attention_mask,
|
||||
)
|
||||
else:
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embeds,
|
||||
timestep=timesteps_tensor,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
|
||||
|
||||
return BatchFeature(
|
||||
data={
|
||||
"action_pred": actions,
|
||||
"backbone_features": vl_embeds,
|
||||
"state_features": state_features,
|
||||
}
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_action(
|
||||
self,
|
||||
backbone_output: BatchFeature,
|
||||
action_input: BatchFeature,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> BatchFeature:
|
||||
features = self._encode_features(backbone_output, action_input)
|
||||
return self.get_action_with_features(
|
||||
backbone_features=features.backbone_features,
|
||||
state_features=features.state_features,
|
||||
embodiment_id=action_input.embodiment_id,
|
||||
backbone_output=backbone_output,
|
||||
action_input=action_input,
|
||||
options=options,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
|
||||
return BatchFeature(data=batch)
|
||||
|
||||
|
||||
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
|
||||
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
|
||||
|
||||
|
||||
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
||||
if Qwen3VLConfig is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLConfig is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
return Qwen3VLConfig(
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=True,
|
||||
text_config={
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2048,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 6144,
|
||||
"max_position_embeddings": 262144,
|
||||
"model_type": "qwen3_vl_text",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"rope_scaling": {
|
||||
"mrope_interleaved": True,
|
||||
"mrope_section": [24, 20, 20],
|
||||
"rope_type": "default",
|
||||
},
|
||||
"rope_theta": 5000000,
|
||||
"tie_word_embeddings": True,
|
||||
"use_cache": True,
|
||||
"vocab_size": 151936,
|
||||
},
|
||||
vision_config={
|
||||
"deepstack_visual_indexes": [5, 11, 17],
|
||||
"depth": 24,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1024,
|
||||
"in_channels": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"model_type": "qwen3_vl",
|
||||
"num_heads": 16,
|
||||
"num_position_embeddings": 2304,
|
||||
"out_hidden_size": 2048,
|
||||
"patch_size": 16,
|
||||
"spatial_merge_size": 2,
|
||||
"temporal_patch_size": 2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_backbone_cls(config: GR00TN17Config):
|
||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||
return Qwen3Backbone
|
||||
if config.backbone_model_type == "qwen":
|
||||
logger.warning(
|
||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||
"backbone because backbone_model_type='qwen'.",
|
||||
config.model_name,
|
||||
)
|
||||
return Qwen3Backbone
|
||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||
|
||||
|
||||
class GR00TN17(PreTrainedModel):
|
||||
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
|
||||
|
||||
config_class = GR00TN17Config
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GR00TN17Config,
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_backbone_weights: bool = True,
|
||||
):
|
||||
super().__init__(config)
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
self.config = config
|
||||
backbone_cls = get_backbone_cls(config)
|
||||
self.backbone = backbone_cls(
|
||||
model_name=config.model_name,
|
||||
tune_llm=config.tune_llm,
|
||||
tune_visual=config.tune_visual,
|
||||
select_layer=config.select_layer,
|
||||
reproject_vision=config.reproject_vision,
|
||||
use_flash_attention=config.use_flash_attention,
|
||||
load_bf16=config.load_bf16,
|
||||
tune_top_llm_layers=config.tune_top_llm_layers,
|
||||
trainable_params_fp32=config.backbone_trainable_params_fp32,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_pretrained_weights=load_backbone_weights,
|
||||
)
|
||||
self.action_head = GR00TN17ActionHead(config)
|
||||
self.post_init()
|
||||
|
||||
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
|
||||
global tree
|
||||
if tree is None:
|
||||
require_package("dm-tree", extra="groot", import_name="tree")
|
||||
tree = importlib.import_module("tree")
|
||||
backbone_inputs = self.backbone.prepare_input(inputs)
|
||||
action_inputs = self.action_head.prepare_input(inputs)
|
||||
|
||||
def to_device_with_dtype(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if torch.is_floating_point(x):
|
||||
return x.to(self.device, dtype=self.dtype)
|
||||
return x.to(self.device)
|
||||
|
||||
return (
|
||||
tree.map_structure(to_device_with_dtype, backbone_inputs),
|
||||
tree.map_structure(to_device_with_dtype, action_inputs),
|
||||
)
|
||||
|
||||
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head(backbone_outputs, action_inputs)
|
||||
|
||||
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
|
||||
backbone_inputs, action_inputs = self.prepare_input(inputs)
|
||||
backbone_outputs = self.backbone(backbone_inputs)
|
||||
return self.action_head.get_action(backbone_outputs, action_inputs, options)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(iter(self.parameters())).dtype
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
tune_visual = kwargs.pop("tune_visual", True)
|
||||
tune_llm = kwargs.pop("tune_llm", False)
|
||||
tune_projector = kwargs.pop("tune_projector", True)
|
||||
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
|
||||
tune_vlln = kwargs.pop("tune_vlln", True)
|
||||
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
|
||||
try:
|
||||
local_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
repo_type="model",
|
||||
revision=kwargs.get("revision"),
|
||||
cache_dir=kwargs.get("cache_dir"),
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
token=kwargs.get("token"),
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError):
|
||||
local_model_path = pretrained_model_name_or_path
|
||||
|
||||
pretrained_model = super().from_pretrained(
|
||||
local_model_path,
|
||||
transformers_loading_kwargs=transformers_loading_kwargs,
|
||||
load_backbone_weights=load_backbone_weights,
|
||||
**kwargs,
|
||||
)
|
||||
pretrained_model.backbone.set_trainable_parameters(
|
||||
tune_visual=tune_visual,
|
||||
tune_llm=tune_llm,
|
||||
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
|
||||
)
|
||||
pretrained_model.action_head.set_trainable_parameters(
|
||||
tune_projector=tune_projector,
|
||||
tune_diffusion_model=tune_diffusion_model,
|
||||
tune_vlln=tune_vlln,
|
||||
)
|
||||
return pretrained_model
|
||||
|
||||
|
||||
def _register_with_transformers() -> None:
|
||||
if AutoConfig is None or AutoModel is None:
|
||||
return
|
||||
try:
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
|
||||
try:
|
||||
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoModel.register(GR00TN17Config, GR00TN17)
|
||||
|
||||
|
||||
_register_with_transformers()
|
||||
@@ -17,22 +17,13 @@
|
||||
"""
|
||||
Groot Policy Wrapper for LeRobot Integration
|
||||
|
||||
Minimal integration that delegates to Isaac-GR00T components where possible
|
||||
without porting their code. The intent is to:
|
||||
|
||||
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
|
||||
- Optionally align action horizon similar to gr00t_finetune.py
|
||||
- Expose predict_action via GR00T model.get_action
|
||||
- Provide a training forward that can call the GR00T model forward if batch
|
||||
structure matches.
|
||||
|
||||
Notes:
|
||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||
possible without porting their code. Dataset loading and training
|
||||
orchestration are handled by LeRobot's standard training stack.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
@@ -46,8 +37,19 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_groot import GrootConfig
|
||||
from .groot_n1 import GR00TN15
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_groot import (
|
||||
GROOT_N1_5,
|
||||
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||
GROOT_N1_7,
|
||||
GrootConfig,
|
||||
infer_groot_model_version,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound="GrootPolicy")
|
||||
|
||||
@@ -67,37 +69,35 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Initialize GR00T model using ported components
|
||||
self._groot_model = self._create_groot_model()
|
||||
self._action_queue_steps = self._resolve_action_queue_steps()
|
||||
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T model using Isaac-GR00T API.
|
||||
|
||||
This is only called when creating a NEW policy (not when loading from checkpoint).
|
||||
|
||||
Steps (delegating to Isaac-GR00T):
|
||||
1) Download and load pretrained model via GR00TN15.from_pretrained
|
||||
2) Align action horizon with data_config if provided
|
||||
"""
|
||||
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
model = GR00TN15.from_pretrained(
|
||||
pretrained_model_name_or_path=self.config.base_model_path,
|
||||
tune_llm=self.config.tune_llm,
|
||||
tune_visual=self.config.tune_visual,
|
||||
tune_projector=self.config.tune_projector,
|
||||
tune_diffusion_model=self.config.tune_diffusion_model,
|
||||
)
|
||||
model_kwargs = {
|
||||
"pretrained_model_name_or_path": self.config.base_model_path,
|
||||
"tune_llm": self.config.tune_llm,
|
||||
"tune_visual": self.config.tune_visual,
|
||||
"tune_projector": self.config.tune_projector,
|
||||
"tune_diffusion_model": self.config.tune_diffusion_model,
|
||||
}
|
||||
from .groot_n1_7 import GR00TN17
|
||||
|
||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
||||
model.config.compute_dtype = model.compute_dtype
|
||||
model = GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=True,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
self._action_queue = deque([], maxlen=self._action_queue_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -118,7 +118,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
"""Load Groot policy from pretrained model.
|
||||
|
||||
Handles two cases:
|
||||
1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model
|
||||
1. Base GR00T N1.7 models - loads the raw model
|
||||
2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors
|
||||
|
||||
Args:
|
||||
@@ -141,9 +141,15 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
print(
|
||||
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
|
||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
||||
requested_version = (
|
||||
normalize_groot_model_version(config.model_version)
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
pretrained_name_or_path,
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
@@ -174,7 +180,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if is_finetuned_checkpoint:
|
||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
@@ -190,11 +196,15 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
)
|
||||
|
||||
# This is a base GR00T model - load it fresh
|
||||
print("Detected base GR00T model, loading from HuggingFace...")
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
|
||||
config = GrootConfig(
|
||||
model_version=model_version,
|
||||
base_model_path=str(pretrained_name_or_path),
|
||||
)
|
||||
|
||||
# Add minimal visual feature required for validation
|
||||
# validate_features() will automatically add state and action features
|
||||
@@ -215,6 +225,16 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
raise ValueError(message)
|
||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||
# in __init__ via _create_groot_model()
|
||||
policy = cls(config)
|
||||
@@ -225,21 +245,160 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _resolve_action_queue_steps(self) -> int:
|
||||
n_action_steps = int(self.config.n_action_steps)
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
execution_horizon = infer_groot_n1_7_action_execution_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
horizons = [n_action_steps]
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
if execution_horizon is not None:
|
||||
horizons.append(execution_horizon)
|
||||
return min(horizons)
|
||||
|
||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||
|
||||
horizons = [actions.shape[1]]
|
||||
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
|
||||
self.config.base_model_path,
|
||||
self.config.embodiment_tag,
|
||||
)
|
||||
if checkpoint_action_horizon is not None:
|
||||
horizons.append(checkpoint_action_horizon)
|
||||
|
||||
for horizon in (self.config.chunk_size, self.config.n_action_steps):
|
||||
horizon = int(horizon)
|
||||
if horizon > 0:
|
||||
horizons.append(horizon)
|
||||
|
||||
return max(1, min(horizons))
|
||||
|
||||
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
if include_action:
|
||||
allowed_base.update({"action", "action_mask"})
|
||||
|
||||
allowed_base.update(
|
||||
{
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"pixel_values",
|
||||
"image_grid_thw",
|
||||
"mm_token_type_ids",
|
||||
"pixel_values_videos",
|
||||
"video_grid_thw",
|
||||
}
|
||||
)
|
||||
allowed_base.add("action_mask")
|
||||
|
||||
return {
|
||||
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
|
||||
def _prepare_n1_7_rtc_inputs(
|
||||
self,
|
||||
inputs: dict[str, Tensor],
|
||||
*,
|
||||
inference_delay: object,
|
||||
prev_chunk_left_over: object,
|
||||
) -> tuple[dict[str, Tensor], dict[str, object] | None]:
|
||||
if prev_chunk_left_over is None:
|
||||
return inputs, None
|
||||
if not isinstance(prev_chunk_left_over, torch.Tensor):
|
||||
raise TypeError("prev_chunk_left_over must be a torch.Tensor for GR00T N1.7 RTC.")
|
||||
if prev_chunk_left_over.numel() == 0:
|
||||
return inputs, None
|
||||
|
||||
prev_actions = prev_chunk_left_over
|
||||
if prev_actions.ndim == 2:
|
||||
prev_actions = prev_actions.unsqueeze(0)
|
||||
elif prev_actions.ndim != 3:
|
||||
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||
|
||||
state = inputs.get("state")
|
||||
if state is None:
|
||||
raise ValueError("GR00T N1.7 RTC requires `state` in the preprocessed batch.")
|
||||
batch_size = state.shape[0]
|
||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||
elif prev_actions.shape[0] != batch_size:
|
||||
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||
|
||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||
# provided prefix row as a real action constraint, so strip that padding
|
||||
# before constructing the native overlap options.
|
||||
valid_prefix_rows = prev_actions.detach().abs().sum(dim=(0, 2)) > 0
|
||||
if valid_prefix_rows.any():
|
||||
valid_prefix_steps = int(valid_prefix_rows.nonzero()[-1].item()) + 1
|
||||
prev_actions = prev_actions[:, :valid_prefix_steps, :]
|
||||
else:
|
||||
return inputs, None
|
||||
|
||||
model_action_horizon = int(
|
||||
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||
)
|
||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||
if prev_actions.shape[1] > model_action_horizon:
|
||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||
|
||||
action_horizon = int(prev_actions.shape[1])
|
||||
if action_horizon <= 0:
|
||||
return inputs, None
|
||||
|
||||
if prev_actions.shape[2] > max_action_dim:
|
||||
prev_actions = prev_actions[:, :, :max_action_dim]
|
||||
elif prev_actions.shape[2] < max_action_dim:
|
||||
pad = torch.zeros(
|
||||
prev_actions.shape[0],
|
||||
prev_actions.shape[1],
|
||||
max_action_dim - prev_actions.shape[2],
|
||||
dtype=prev_actions.dtype,
|
||||
device=prev_actions.device,
|
||||
)
|
||||
prev_actions = torch.cat([prev_actions, pad], dim=2)
|
||||
|
||||
prev_actions = prev_actions.to(device=state.device, dtype=state.dtype)
|
||||
|
||||
rtc_config = getattr(self.config, "rtc_config", None)
|
||||
execution_horizon = int(getattr(rtc_config, "execution_horizon", action_horizon))
|
||||
overlap_steps = max(0, min(action_horizon, execution_horizon))
|
||||
if overlap_steps == 0:
|
||||
return inputs, None
|
||||
|
||||
try:
|
||||
frozen_steps = int(inference_delay or 0)
|
||||
except (TypeError, ValueError):
|
||||
frozen_steps = 0
|
||||
frozen_steps = max(0, min(frozen_steps, overlap_steps))
|
||||
|
||||
options = {
|
||||
"action_horizon": action_horizon,
|
||||
"rtc_overlap_steps": overlap_steps,
|
||||
"rtc_frozen_steps": frozen_steps,
|
||||
"rtc_ramp_rate": float(getattr(self._groot_model.config, "rtc_ramp_rate", 6.0)),
|
||||
}
|
||||
|
||||
inputs = dict(inputs)
|
||||
inputs["action"] = prev_actions
|
||||
return inputs, options
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Training forward pass.
|
||||
|
||||
Delegates to Isaac-GR00T model.forward when inputs are compatible.
|
||||
"""
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
|
||||
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
|
||||
@@ -248,38 +407,54 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||
loss = outputs.get("loss")
|
||||
if loss is None:
|
||||
raise RuntimeError(
|
||||
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||
"'action' and 'action_mask'; check the preprocessor output."
|
||||
)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: object) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
|
||||
For N1.7, LeRobot's RTC leftovers are converted into the native GR00T
|
||||
action-overlap options before calling the underlying model.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch
|
||||
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
|
||||
allowed_base = {"state", "state_mask", "embodiment_id"}
|
||||
groot_inputs = {
|
||||
k: v
|
||||
for k, v in batch.items()
|
||||
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
|
||||
}
|
||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
|
||||
groot_options = None
|
||||
if self.config.model_version == GROOT_N1_7:
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
|
||||
# Get device from model parameters
|
||||
device = next(self.parameters()).device
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
# Use bf16 autocast for inference to keep memory low and match backbone dtype
|
||||
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
if groot_options is not None:
|
||||
outputs = self._groot_model.get_action(groot_inputs, options=groot_options)
|
||||
else:
|
||||
outputs = self._groot_model.get_action(groot_inputs)
|
||||
|
||||
actions = outputs.get("action_pred")
|
||||
|
||||
prediction_horizon = self._resolve_prediction_horizon(actions)
|
||||
actions = actions[:, :prediction_horizon]
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
@@ -292,40 +467,28 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
|
||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
||||
is compiled against a different PyTorch version than what's currently installed.
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
"""
|
||||
|
||||
# Set environment variables to handle Flash Attention compatibility
|
||||
# These help with symbol resolution issues
|
||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
||||
|
||||
# Try to import flash_attn and handle failures gracefully
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"[GROOT] Flash Attention not available: {e}")
|
||||
print("[GROOT] Will use fallback attention mechanism")
|
||||
except Exception as e:
|
||||
if "undefined symbol" in str(e):
|
||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
||||
print(" pip uninstall flash-attn")
|
||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
else:
|
||||
print(f"[GROOT] Flash Attention error: {e}")
|
||||
print("[GROOT] Continuing with fallback attention mechanism")
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,47 +0,0 @@
|
||||
from pathlib import Path
|
||||
from shutil import copytree
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
|
||||
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
|
||||
|
||||
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
|
||||
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
vendor_dir = Path(vendor_dir)
|
||||
|
||||
try:
|
||||
# Populate/refresh cache with vendor files to ensure a complete processor directory
|
||||
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
|
||||
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
|
||||
except Exception as exc: # nosec: B110
|
||||
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
|
||||
|
||||
required_assets = [
|
||||
"vocab.json",
|
||||
"merges.txt",
|
||||
"added_tokens.json",
|
||||
"chat_template.json",
|
||||
"special_tokens_map.json",
|
||||
"config.json",
|
||||
"generation_config.json",
|
||||
"preprocessor_config.json",
|
||||
"processor_config.json",
|
||||
"tokenizer_config.json",
|
||||
]
|
||||
|
||||
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
|
||||
|
||||
for fname in required_assets:
|
||||
dst = cache_dir / fname
|
||||
if not dst.exists():
|
||||
print(f"[GROOT] Fetching {fname}")
|
||||
hf_hub_download(
|
||||
repo_id=assets_repo,
|
||||
filename=fname,
|
||||
repo_type="model",
|
||||
local_dir=str(cache_dir),
|
||||
)
|
||||
@@ -1,42 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
|
||||
inference recipe on lerobot.
|
||||
|
||||
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
|
||||
|
||||
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
|
||||
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
|
||||
(the "high-level subtask prediction" of the paper, §IV.D),
|
||||
* AR text generation at inference (:meth:`PI052Policy.select_message`),
|
||||
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
|
||||
text head against missing context at inference.
|
||||
|
||||
See ``src/lerobot/configs/recipes/subtasks_vqa.yaml`` for the
|
||||
canonical training recipe and
|
||||
``examples/training/pi052_hirobot.slurm`` for the launcher.
|
||||
"""
|
||||
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .modeling_pi052 import PI052Policy
|
||||
from .processor_pi052 import make_pi052_pre_post_processors
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
__all__ = [
|
||||
"PI052Config",
|
||||
"PI052Policy",
|
||||
"PI052TextTokenizerStep",
|
||||
"make_pi052_pre_post_processors",
|
||||
]
|
||||
@@ -1,235 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
|
||||
hierarchical inference recipe.
|
||||
|
||||
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
|
||||
~300M Gemma action expert, joint training with FAST tokens during
|
||||
pre-train and flow matching during post-train), but with the
|
||||
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
|
||||
to predict both:
|
||||
|
||||
* **subtask strings** at the high level (cross-entropy on the LM
|
||||
head), and
|
||||
* **action chunks** at the low level (flow matching on the
|
||||
action-expert tokens).
|
||||
|
||||
This is the dual-head co-training pattern from the paper:
|
||||
|
||||
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, ℓ)‖²
|
||||
|
||||
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
|
||||
inference into a text-prediction step followed by an action-prediction
|
||||
step, which the multi-rate ``PI052Runtime`` (in
|
||||
``lerobot.policies.pi052.inference``) drives at separate rates.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
|
||||
from ..pi05.configuration_pi05 import PI05Config
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi052")
|
||||
@dataclass
|
||||
class PI052Config(PI05Config):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
|
||||
|
||||
Recipe-driven dual-head training: the flow head supervises actions,
|
||||
the LM head supervises subtask / plan / memory / VQA text. The
|
||||
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
|
||||
shipped alongside this policy. Set to ``None`` to disable recipe
|
||||
rendering and fall back to π0.5's single-task ``Task: ... Action:``
|
||||
prompt path (unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = False
|
||||
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
|
||||
chat template, so we don't apply one. The recipe renderer's output
|
||||
is concatenated as a plain prefix + assistant suffix instead,
|
||||
mirroring how the π0.5 paper's high-level inference samples text
|
||||
auto-regressively after the prefix."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
# Paper §IV.D uses α=10 between the flow and text terms, assuming
|
||||
# text is a rare auxiliary task. With the recipe stack the flow-only
|
||||
# `low_level` branch fires on a large share of samples, so α=10
|
||||
# swamps the LM head and collapses generation into degenerate
|
||||
# repetition. We use the milder 5:1 split here.
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / π0.5 behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 5.0
|
||||
"""Weight on the action-expert flow-matching term. ``5.0`` — a milder
|
||||
flow:text split than the paper's α=10, since the flow-only
|
||||
``low_level`` recipe already gives the action expert frequent
|
||||
gradient. Lower it further if the LM head still underfits."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
|
||||
The existing ``PI05Policy`` zeroes / freezes the head on load
|
||||
because it never reads from it. Must be ``True`` for π0.5-style
|
||||
hierarchical inference."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# Randomly drop non-target context messages so the LM head learns
|
||||
# to handle missing /
|
||||
# stale plan / memory at inference. Defaults to 0.0 so behaviour
|
||||
# is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
|
||||
# FAST discrete-action supervision — paper §III.B-C ------------------
|
||||
# When enabled, actions are *also* tokenised via the FAST tokenizer
|
||||
# ("physical-intelligence/fast") and supervised with cross-entropy
|
||||
# on the PaliGemma LM head — exactly as in the paper's pre-training
|
||||
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
|
||||
# ActionTokenizerProcessorStep is wired into the preprocessor
|
||||
# pipeline when this flag is set; the loss is computed in
|
||||
# PI052Policy.forward.
|
||||
enable_fast_action_loss: bool = True
|
||||
"""If True, tokenise actions with the FAST tokenizer and add a
|
||||
cross-entropy loss on the LM head. On by default to match the
|
||||
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
|
||||
§III.B-C Eq. 1). Set to False if you only want the
|
||||
post-training-style flow + text recipe."""
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
"""HF identifier for the FAST action tokenizer."""
|
||||
|
||||
max_action_tokens: int = 256
|
||||
"""Maximum number of FAST tokens per action chunk."""
|
||||
|
||||
fast_skip_tokens: int = 128
|
||||
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
|
||||
collisions with PaliGemma's text vocabulary."""
|
||||
|
||||
fast_action_loss_weight: float = 1.0
|
||||
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
|
||||
|
||||
auto_fit_fast_tokenizer: bool = False
|
||||
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
|
||||
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
|
||||
base_tokenizer_name, fit_samples)``. On cache miss, it loads
|
||||
``action_tokenizer_name`` as a base, samples
|
||||
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
|
||||
``.fit()``, saves the result, and uses *that* fitted path as the
|
||||
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
|
||||
explicitly recommend per-dataset fitting for best compression.
|
||||
|
||||
Off by default because the fit requires a separate pre-training
|
||||
pass over the dataset (~1-2 min on a medium dataset) and depends
|
||||
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
|
||||
when you want paper-faithful compression; leave off to fall back
|
||||
on the universal ``physical-intelligence/fast`` codebook."""
|
||||
|
||||
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
|
||||
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
|
||||
|
||||
fast_tokenizer_fit_samples: int = 1024
|
||||
"""Number of action chunks to sample for the fit. The FAST paper uses
|
||||
a few thousand; 1024 is a reasonable default for medium datasets."""
|
||||
|
||||
# Knowledge insulation — paper §III.B --------------------------------
|
||||
# When enabled, gradients from the action expert's flow loss are
|
||||
# blocked from flowing back into the VLM's K/V projections. This
|
||||
# prevents the action loss from over-fitting the language backbone
|
||||
# to robot-specific features. Implemented in ``modeling_pi052`` as
|
||||
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
|
||||
# that splits queries into VLM and action halves and ``.detach()``-s
|
||||
# the VLM K/V tensors used in the action-half's attention.
|
||||
knowledge_insulation: bool = False
|
||||
"""If True, route every transformer layer through the KI
|
||||
attention path that blocks action→VLM gradient flow on K/V."""
|
||||
|
||||
# Learning-rate defaults --------------------------------------------
|
||||
# pi052 inherits π0.5's openpi-validated optimizer config (peak LR
|
||||
# 2.5e-5, cosine→2.5e-6, 1k warmup, AdamW (0.9, 0.95), wd=0.01,
|
||||
# grad_clip=1.0). The only place pi052 needs to diverge from pi05
|
||||
# is the LM-head LR multiplier: pi05 has no text supervision so the
|
||||
# head doesn't get gradients; pi052 always has text supervision
|
||||
# (subtask / memory / VQA) via the recipe, and under KI the LM head
|
||||
# only sees gradients on ~30–45% of the batch (the text-CE mask
|
||||
# share of the recipe). Under aggressive cosine decay this is too
|
||||
# weak to keep the head pinned, so it drifts back toward PaliGemma's
|
||||
# pretrained ``<loc>`` first-token bias. 5x is the documented fix
|
||||
# (see ``PI05Config.lm_head_lr_scale`` docstring); the wiring is
|
||||
# already in ``PI05Policy.get_optim_params`` — it splits the LM head
|
||||
# + tied ``embed_tokens`` into their own param group while sharing
|
||||
# the same cosine lambda, so the 5x ratio is preserved across decay.
|
||||
lm_head_lr_scale: float = 5.0
|
||||
|
||||
# PaLM-style z-loss on text CE. Penalises the log-partition function
|
||||
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
|
||||
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
|
||||
# while CE stays low, because a uniform additive logit bias cancels in
|
||||
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
|
||||
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
|
||||
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
|
||||
# commonly cited weight; set 0 to disable entirely.
|
||||
text_ce_z_loss_weight: float = 1e-4
|
||||
|
||||
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
|
||||
# unconditionally at model build time — see ``_enable_hf_kernels``
|
||||
# in ``modeling_pi052``. The patch is process-global, idempotent
|
||||
# and degrades gracefully if ``liger-kernel`` is missing. Measured
|
||||
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||
use_hf_kernels: bool = True
|
||||
"""Deprecated. Liger HF kernels are patched unconditionally by
|
||||
``_enable_hf_kernels`` — this field is retained as a no-op for
|
||||
backward compatibility with checkpoints saved before commit
|
||||
d70c8104 (which still serialize ``use_hf_kernels: true`` into
|
||||
``config.json``). Loading those configs would otherwise raise
|
||||
``DecodingError: The fields use_hf_kernels are not valid for
|
||||
PI052Config`` (job 22164492). Remove in a future major bump."""
|
||||
|
||||
# Optimizer foreach/fused. pi052 carries these locally because the shared
|
||||
# PI05Config (kept identical to upstream main) does not define them; the
|
||||
# checkpoints we train serialize both keys into config.json, so they must
|
||||
# be valid PI052Config fields and flow into the AdamW preset below.
|
||||
optimizer_foreach: bool | None = False
|
||||
optimizer_fused: bool | None = True
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
foreach=self.optimizer_foreach,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
# we're training it. Override the π0.5 default
|
||||
# (``train_expert_only=True``) unless the user explicitly opts
|
||||
# out of text training via ``text_loss_weight=0``.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
self.train_expert_only = False
|
||||
@@ -1,304 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset-specific FAST action tokenizer fitting.
|
||||
|
||||
The published ``physical-intelligence/fast`` tokenizer is a *universal*
|
||||
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
|
||||
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
|
||||
π0.5 itself, the recommended practice is to **finetune the tokenizer on
|
||||
your specific dataset's action distribution** before training the
|
||||
policy — same way one would adapt a language tokenizer to a domain
|
||||
corpus. Without this finetune step, action sequences from your robot
|
||||
may require more tokens per chunk than necessary, lowering effective
|
||||
compression and slowing convergence of the action-CE loss.
|
||||
|
||||
This module provides a single utility, :func:`fit_fast_tokenizer`,
|
||||
that does the finetune. The training entry point invokes it
|
||||
automatically when the policy's ``enable_fast_action_loss`` and
|
||||
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
|
||||
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
|
||||
|
||||
The fitted tokenizer is saved to
|
||||
``{cache_dir}/{dataset_hash}_{base_hash}/`` so successive training
|
||||
runs over the same dataset re-use it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
|
||||
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
|
||||
# that's the image / feature-extractor convention). Centralised here so
|
||||
# the cache-hit check and the rank-N readiness wait agree on the same
|
||||
# sentinel.
|
||||
_CACHE_SENTINEL = "processor_config.json"
|
||||
|
||||
|
||||
def _dataset_signature(
|
||||
dataset_repo_id: str,
|
||||
base_tokenizer_name: str,
|
||||
n_samples: int,
|
||||
chunk_size: int,
|
||||
) -> str:
|
||||
"""Deterministic short hash for naming the cache directory.
|
||||
|
||||
Keys on (dataset, base tokenizer, sample count, chunk size) so any
|
||||
of those changing re-runs the fit. ``chunk_size`` matters because
|
||||
the tokenizer is fit on chunks of that length.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
h.update(dataset_repo_id.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(base_tokenizer_name.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(n_samples).encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(chunk_size).encode("utf-8"))
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def fit_fast_tokenizer(
|
||||
*,
|
||||
dataset_repo_id: str,
|
||||
cache_dir: str | Path,
|
||||
base_tokenizer_name: str = "physical-intelligence/fast",
|
||||
n_samples: int = 1024,
|
||||
chunk_size: int = 50,
|
||||
seed: int = 42,
|
||||
) -> str:
|
||||
"""Fit a FAST tokenizer on a LeRobot dataset's action distribution.
|
||||
|
||||
Args:
|
||||
dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
|
||||
cache_dir: Directory under which to save (and look up) fitted
|
||||
tokenizers. The actual save path is
|
||||
``{cache_dir}/{signature}``.
|
||||
base_tokenizer_name: HF identifier for the base FAST tokenizer
|
||||
to finetune from. ``physical-intelligence/fast`` is the
|
||||
universal one.
|
||||
n_samples: Number of action chunks to sample for the fit. The
|
||||
FAST paper uses a few thousand; ``1024`` is a good default
|
||||
for medium datasets.
|
||||
chunk_size: Length of each action chunk (matches
|
||||
``policy.chunk_size``). The FAST tokenizer is fit on
|
||||
sequences of this length.
|
||||
seed: RNG seed for sample selection.
|
||||
|
||||
Returns:
|
||||
The local path to the fitted tokenizer. Passed directly to
|
||||
``--policy.action_tokenizer_name`` for the training run.
|
||||
|
||||
Raises:
|
||||
ImportError: If the ``transformers`` library doesn't expose
|
||||
``AutoProcessor`` or the FAST tokenizer doesn't have a
|
||||
``.fit()`` method (then you're on an older FAST snapshot —
|
||||
update to the current published model).
|
||||
FileNotFoundError: If the dataset can't be loaded.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
|
||||
out_dir = cache_dir / sig
|
||||
|
||||
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
|
||||
logger.info(
|
||||
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
|
||||
"dataset=%s base=%s n_samples=%d",
|
||||
out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
|
||||
)
|
||||
return str(out_dir)
|
||||
|
||||
# DDP-safe fit: only the (local) main process actually fits + saves;
|
||||
# other ranks poll the cache sentinel until the leader is done.
|
||||
# Without this guard, all N ranks fit concurrently and race on
|
||||
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
|
||||
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
|
||||
# and compiles a ``.pyc`` — concurrent writers occasionally produce
|
||||
# a stale / partial ``.pyc`` and the subsequent ``from .. import
|
||||
# UniversalActionProcessor`` raises ``AttributeError``.
|
||||
is_leader = (
|
||||
int(os.environ.get("RANK", "0")) == 0
|
||||
and int(os.environ.get("LOCAL_RANK", "0")) == 0
|
||||
)
|
||||
if not is_leader:
|
||||
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
|
||||
start = time.monotonic()
|
||||
while not (out_dir / _CACHE_SENTINEL).exists():
|
||||
if time.monotonic() - start > timeout_s:
|
||||
raise RuntimeError(
|
||||
f"FAST tokenizer fit: non-leader rank timed out after "
|
||||
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
|
||||
"Leader rank likely crashed during the fit."
|
||||
)
|
||||
time.sleep(2.0)
|
||||
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
|
||||
return str(out_dir)
|
||||
|
||||
logger.info(
|
||||
"FAST tokenizer cache miss — fitting on dataset=%s "
|
||||
"base=%s n_samples=%d chunk_size=%d → %s",
|
||||
dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
|
||||
)
|
||||
|
||||
from transformers import AutoProcessor # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||
|
||||
# Stream a single episode's worth of action chunks at a time so
|
||||
# we don't blow memory on huge datasets. Random episode +
|
||||
# random start offset gives a reasonable spread.
|
||||
#
|
||||
# Actions are read straight from the underlying HF dataset's
|
||||
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
|
||||
# training item (delta-timestamp expansion + video decode + image
|
||||
# transforms); a single bad video frame would then throw and, since
|
||||
# the failure was swallowed at debug level, silently starve the fit
|
||||
# of every chunk. The action column carries no video, so reading it
|
||||
# directly is both faster and immune to decode errors.
|
||||
rng = np.random.default_rng(seed)
|
||||
actions_buf: list[np.ndarray] = []
|
||||
|
||||
# Resolve the dataset's data parquet shards directly, sidestepping
|
||||
# ``LeRobotDataset(repo_id, episodes=[N])`` which on v3-format
|
||||
# datasets routes through HF datasets'' split lookup and raises
|
||||
# ``ValueError: Instruction "train" corresponds to no data!`` for
|
||||
# every episode (job 22182985 looped through 13,293 skipped episodes
|
||||
# for ~2.5 h before NCCL killed it). Reading the ``action`` column
|
||||
# straight from the parquet shards is also faster: each per-episode
|
||||
# ``LeRobotDataset`` instantiation re-parses every meta file.
|
||||
from huggingface_hub import snapshot_download # noqa: PLC0415
|
||||
import pyarrow as _pa # noqa: PLC0415
|
||||
import pyarrow.parquet as _pq # noqa: PLC0415
|
||||
|
||||
snap = Path(snapshot_download(repo_id=dataset_repo_id, repo_type="dataset"))
|
||||
data_files = sorted((snap / "data").glob("chunk-*/file-*.parquet"))
|
||||
if not data_files:
|
||||
raise RuntimeError(
|
||||
f"FAST fit: no ``data/chunk-*/file-*.parquet`` shards found under {snap!s}."
|
||||
)
|
||||
|
||||
# Read just the (episode_index, action) columns once across all
|
||||
# shards. This is the same pattern used elsewhere in the codebase
|
||||
# for whole-dataset audits and stays under ~2 GB even on 32 k-episode
|
||||
# / 29 M-frame datasets because the action column is a fixed-length
|
||||
# float vector.
|
||||
tables = [_pq.read_table(f, columns=["episode_index", "action"]) for f in data_files]
|
||||
table = _pa.concat_tables(tables)
|
||||
eps = table["episode_index"].to_numpy()
|
||||
acts_col = table["action"]
|
||||
# ``action`` may be a fixed-shape ListArray or a 2-D NumericArray;
|
||||
# ``to_numpy(zero_copy_only=False)`` produces an object array of
|
||||
# 1-D NumPy actions either way, which we stack into (N, D).
|
||||
try:
|
||||
acts = np.stack(acts_col.to_numpy(zero_copy_only=False)).astype(np.float32)
|
||||
except Exception: # noqa: BLE001
|
||||
# Fallback path for nested-list types: flatten via to_pylist().
|
||||
acts = np.asarray(acts_col.to_pylist(), dtype=np.float32)
|
||||
if acts.ndim != 2:
|
||||
raise RuntimeError(
|
||||
f"FAST fit: expected ``action`` rows to be 1-D vectors; got shape {acts.shape}."
|
||||
)
|
||||
|
||||
# Episode index → slice (start, stop) into ``acts`` along axis 0.
|
||||
# ``eps`` is monotonically increasing within each parquet shard but
|
||||
# we make no assumption across shards — sort once and group.
|
||||
order = np.argsort(eps, kind="stable")
|
||||
eps_sorted = eps[order]
|
||||
boundaries = np.searchsorted(eps_sorted, np.arange(int(eps_sorted.max()) + 2))
|
||||
ep_to_slice: dict[int, tuple[int, int]] = {
|
||||
int(ep): (int(boundaries[ep]), int(boundaries[ep + 1]))
|
||||
for ep in range(len(boundaries) - 1)
|
||||
if boundaries[ep] < boundaries[ep + 1]
|
||||
}
|
||||
num_episodes = len(ep_to_slice)
|
||||
# ``acts`` is in original (un-sorted-by-episode) row order; reorder
|
||||
# so per-episode slices are contiguous.
|
||||
acts = acts[order]
|
||||
|
||||
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
||||
collected = 0
|
||||
eps_visited = 0
|
||||
short_episodes = 0
|
||||
ep_indices = list(ep_to_slice.keys())
|
||||
for ep_idx in rng.permutation(ep_indices):
|
||||
if collected >= n_samples:
|
||||
break
|
||||
start, stop = ep_to_slice[int(ep_idx)]
|
||||
ep_actions = acts[start:stop]
|
||||
if ep_actions.shape[0] < chunk_size:
|
||||
short_episodes += 1
|
||||
continue
|
||||
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
|
||||
for s in starts:
|
||||
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
|
||||
collected += 1
|
||||
if collected >= n_samples:
|
||||
break
|
||||
eps_visited += 1
|
||||
|
||||
if not actions_buf:
|
||||
raise RuntimeError(
|
||||
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
|
||||
f"all {num_episodes} episodes were shorter than chunk_size="
|
||||
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
|
||||
"``action`` column. Lower ``chunk_size`` to match your episode "
|
||||
"lengths."
|
||||
)
|
||||
|
||||
actions = np.stack(actions_buf, axis=0).astype(np.float32) # (N, H, D)
|
||||
logger.info(
|
||||
"FAST fit: collected %d chunks of shape %s from %d episodes",
|
||||
actions.shape[0], actions.shape[1:], eps_visited,
|
||||
)
|
||||
|
||||
# Quantile-normalise per dimension before fitting.
|
||||
#
|
||||
# The FAST tokenizer DCT-transforms actions, scales by ``scale`` and
|
||||
# rounds to integer tokens; the integer *range* must fit the
|
||||
# codebook (vocab_size, default 1024). Raw motor units (e.g. encoder
|
||||
# ticks) blow that range up — hence "Vocab size 1024 is too small".
|
||||
# More importantly, at training time ``ActionTokenizerProcessorStep``
|
||||
# runs *after* the QUANTILES ``NormalizerProcessorStep``, so it
|
||||
# encodes normalised actions. Fitting on raw actions would mismatch
|
||||
# that space. We replicate QUANTILES normalisation here (per-dim
|
||||
# [q01, q99] → [-1, 1], clipped) so the fit and the training-time
|
||||
# encode see the same distribution.
|
||||
flat = actions.reshape(-1, actions.shape[-1])
|
||||
q01 = np.quantile(flat, 0.01, axis=0)
|
||||
q99 = np.quantile(flat, 0.99, axis=0)
|
||||
span = np.where((q99 - q01) > 1e-6, q99 - q01, 1.0)
|
||||
actions = np.clip((actions - q01) / span * 2.0 - 1.0, -1.0, 1.0).astype(np.float32)
|
||||
|
||||
base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
|
||||
if not hasattr(base, "fit"):
|
||||
raise ImportError(
|
||||
f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
|
||||
"method — your transformers / model snapshot is too old. Update "
|
||||
"to the current ``physical-intelligence/fast`` revision."
|
||||
)
|
||||
|
||||
fitted = base.fit(actions)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
fitted.save_pretrained(str(out_dir))
|
||||
logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
|
||||
return str(out_dir)
|
||||
@@ -1,73 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 inference / runtime orchestration.
|
||||
|
||||
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
|
||||
low_level_execution → LowLevelForward + DispatchAction (high Hz)
|
||||
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
|
||||
memory_update → MemoryUpdateFwd (event: subtask_change)
|
||||
user_interjection_response → UserInterjectionFwd (event: stdin)
|
||||
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||
|
||||
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
|
||||
``run()``.
|
||||
"""
|
||||
|
||||
from .repl import StdinReader
|
||||
from .runtime import PI052Runtime
|
||||
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||
from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||
|
||||
__all__ = [
|
||||
# runtime
|
||||
"PI052Runtime",
|
||||
"StdinReader",
|
||||
# state helpers
|
||||
"initial_runtime_state",
|
||||
"push_log",
|
||||
"set_if_changed",
|
||||
"take_event",
|
||||
# triggers
|
||||
"Trigger",
|
||||
"Tick",
|
||||
"TickClock",
|
||||
"HzTrigger",
|
||||
"EventTrigger",
|
||||
# steps
|
||||
"InferenceStep",
|
||||
"LowLevelForward",
|
||||
"DispatchAction",
|
||||
"HighLevelSubtaskFwd",
|
||||
"MemoryUpdateFwd",
|
||||
"UserInterjectionFwd",
|
||||
"AskVQAFwd",
|
||||
"DispatchToolCalls",
|
||||
# UI
|
||||
"make_state_panel",
|
||||
"print_robot_lines",
|
||||
"print_user_line",
|
||||
]
|
||||
@@ -1,105 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stdin REPL event collector for the PI052 runtime.
|
||||
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
"stop" / "quit" / "exit" → state["stop"] = True
|
||||
"/action" / "/pause" → set state["mode"]
|
||||
ends with "?" → user_vqa_query event
|
||||
starts with "task:" or first line → set runtime task
|
||||
anything else → user_interjection event
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
|
||||
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
|
||||
directly in its REPL / autonomous loops and does *not* wire this
|
||||
collector; it's kept as the documented embedding hook and for tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class StdinReader:
|
||||
"""Non-blocking stdin line collector for the runtime loop."""
|
||||
|
||||
prompt: str = "> "
|
||||
_seen_first_line: bool = field(default=False, init=False)
|
||||
_prompted: bool = field(default=False, init=False)
|
||||
|
||||
def poll(self, state: dict[str, Any]) -> None:
|
||||
"""Drain pending stdin lines into runtime events."""
|
||||
# Print the input prompt once on every fresh tick if we don't
|
||||
# already have a pending line; matches the expected REPL feel.
|
||||
if not self._prompted:
|
||||
print(self.prompt, end="", flush=True)
|
||||
self._prompted = True
|
||||
|
||||
# ``select`` with timeout=0 makes this non-blocking. Only works
|
||||
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
|
||||
try:
|
||||
ready, _, _ = select.select([sys.stdin], [], [], 0)
|
||||
except (ValueError, OSError):
|
||||
return
|
||||
if not ready:
|
||||
return
|
||||
|
||||
line = sys.stdin.readline()
|
||||
if not line: # EOF
|
||||
state["stop"] = True
|
||||
return
|
||||
line = line.strip()
|
||||
self._prompted = False # we'll re-prompt next tick
|
||||
if not line:
|
||||
return
|
||||
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
state["stop"] = True
|
||||
return
|
||||
|
||||
# Slash commands flip the run mode. ``/pause`` stops the action
|
||||
# loop (the action steps gate on ``state["mode"]``); ``/action``
|
||||
# resumes it.
|
||||
if lower.split(" ", 1)[0] in {"/action", "/act", "/run"}:
|
||||
state["mode"] = "action"
|
||||
return
|
||||
if lower in {"/pause", "/p"}:
|
||||
state["mode"] = "paused"
|
||||
queue = state.get("action_queue")
|
||||
if hasattr(queue, "clear"):
|
||||
queue.clear()
|
||||
return
|
||||
|
||||
# First non-control line sets the task if no task is active.
|
||||
if not state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
state["task"] = task
|
||||
print(f"[pi052] Task: {task}", flush=True)
|
||||
self._seen_first_line = True
|
||||
return
|
||||
|
||||
# Question → VQA; statement → interjection.
|
||||
if lower.endswith("?"):
|
||||
state["recent_vqa_query"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||
else:
|
||||
state["recent_interjection"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_interjection")
|
||||
@@ -1,205 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 runtime loop.
|
||||
|
||||
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||
event collector, drives ticks through :class:`TickClock`, and prints
|
||||
state-change updates to the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .runtime_state import initial_runtime_state, push_log
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, TickClock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PI052Runtime:
|
||||
"""Compose the inference pipeline and drive it tick-by-tick."""
|
||||
|
||||
policy: Any
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
|
||||
from :func:`lerobot.tools.get_tools(meta)` when wiring the
|
||||
runtime."""
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
"""Closure returning the current preprocessed observation batch.
|
||||
``None`` for dry-run / language-only sessions."""
|
||||
robot_executor: Callable[[Any], None] | None = None
|
||||
"""Closure that takes one action chunk and forwards it to the
|
||||
robot. ``None`` for dry-run."""
|
||||
event_collector: Callable[[dict], None] | None = None
|
||||
"""Per-tick hook that polls external sources (stdin, network) and
|
||||
appends event names to ``state["events_this_tick"]``."""
|
||||
chunk_hz: float = 4.0
|
||||
ctrl_hz: float = 50.0
|
||||
high_level_hz: float = 1.0
|
||||
max_rate_hz: float = 50.0
|
||||
|
||||
pipeline: list[InferenceStep] = field(init=False)
|
||||
state: dict[str, Any] = field(init=False)
|
||||
_stop: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Subtask + memory + VQA configuration. Pipeline:
|
||||
#
|
||||
# HighLevelSubtaskFwd → generate the next subtask via the LM
|
||||
# head at ~``high_level_hz``; writes
|
||||
# ``current_subtask`` and emits
|
||||
# ``subtask_change`` on a transition.
|
||||
# MemoryUpdateFwd → on ``subtask_change``, refresh
|
||||
# ``current_memory`` from the
|
||||
# ``memory_update`` head.
|
||||
# AskVQAFwd → answer camera-grounded stdin questions.
|
||||
# LowLevelForward → action chunk conditioned on the
|
||||
# generated ``current_subtask``.
|
||||
# DispatchAction → drain the chunk to the robot.
|
||||
# DispatchToolCalls → fire any pending tool calls.
|
||||
#
|
||||
# Order matters: ``HighLevelSubtaskFwd`` must run before
|
||||
# ``MemoryUpdateFwd`` so the event is visible the same tick, and
|
||||
# both must run before ``LowLevelForward`` (which is gated on
|
||||
# "action queue empty") so the chunk consumes the freshest
|
||||
# subtask. ``UserInterjectionFwd`` is still importable but
|
||||
# disabled until plan generation is wired in.
|
||||
self.pipeline = [
|
||||
HighLevelSubtaskFwd(
|
||||
trigger=HzTrigger(self.high_level_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
# Listens for the ``subtask_change`` event raised by
|
||||
# ``HighLevelSubtaskFwd`` and refreshes ``current_memory``.
|
||||
MemoryUpdateFwd(
|
||||
trigger=EventTrigger("subtask_change"),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
AskVQAFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
LowLevelForward(
|
||||
trigger=HzTrigger(self.chunk_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
DispatchAction(
|
||||
trigger=HzTrigger(self.ctrl_hz),
|
||||
robot_executor=self.robot_executor,
|
||||
),
|
||||
DispatchToolCalls(tools=self.tools),
|
||||
]
|
||||
self.state = initial_runtime_state()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_task(self, task: str) -> None:
|
||||
"""Set or replace the active task. Logged for the REPL."""
|
||||
self.state["task"] = task
|
||||
push_log(self.state, f"Task: {task}")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop = True
|
||||
|
||||
def run(self, *, max_ticks: int | None = None) -> None:
|
||||
"""Main loop. Returns when ``stop()`` is called or after
|
||||
``max_ticks`` ticks (useful for tests / dry-run)."""
|
||||
clock = TickClock(max_rate_hz=self.max_rate_hz)
|
||||
while not self._stop:
|
||||
tick = clock.advance()
|
||||
self.state["_tick"] = tick
|
||||
self.state["events_this_tick"] = []
|
||||
self.state["log_lines"] = []
|
||||
|
||||
if self.event_collector is not None:
|
||||
self.event_collector(self.state)
|
||||
if self.state.get("stop"):
|
||||
self._stop = True
|
||||
break
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
self._flush_logs()
|
||||
if max_ticks is not None and tick.index >= max_ticks:
|
||||
break
|
||||
|
||||
self._on_shutdown()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# REPL helper: drive one full pipeline pass and return its logs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def step_once(self) -> list[str]:
|
||||
"""Run one tick of the pipeline and return the log lines.
|
||||
|
||||
Used by the interactive REPL: instead of a background thread,
|
||||
the CLI drives ticks synchronously after each user input. Logs
|
||||
are returned (not printed) so the caller can route them into
|
||||
the rich-Live chat scrollback.
|
||||
"""
|
||||
from .triggers import Tick # noqa: PLC0415
|
||||
|
||||
# Synthesize a tick. We don't need the real wall-clock pacing
|
||||
# here — the REPL drives the runtime, not vice versa — but
|
||||
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
|
||||
# bump it generously so every Hz-triggered step considers
|
||||
# itself due.
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
|
||||
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
|
||||
self.state["log_lines"] = []
|
||||
# ``events_this_tick`` is set up by the caller before
|
||||
# ``step_once`` (the REPL pushes user-driven events first).
|
||||
self.state.setdefault("events_this_tick", [])
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
return list(self.state.get("log_lines") or [])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _flush_logs(self) -> None:
|
||||
for line in self.state.get("log_lines") or []:
|
||||
print(f"[pi052] {line}", flush=True)
|
||||
|
||||
def _on_shutdown(self) -> None:
|
||||
# Drain any queued action chunks safely.
|
||||
queue = self.state.get("action_queue")
|
||||
if isinstance(queue, deque):
|
||||
queue.clear()
|
||||
print("[pi052] runtime stopped", flush=True)
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Runtime state passed between inference steps each tick.
|
||||
|
||||
The runtime threads a single dict through the pipeline; this module
|
||||
documents the shape and provides factories. We use a plain ``dict``
|
||||
rather than a frozen dataclass because steps freely add and remove
|
||||
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
|
||||
…) and dataclass field churn would just get in the way.
|
||||
|
||||
Stable keys (read by multiple steps):
|
||||
|
||||
task str the current top-level task
|
||||
current_plan str | None latest plan emitted by the planner
|
||||
current_subtask str | None latest subtask the policy is executing
|
||||
current_memory str | None latest compressed memory
|
||||
recent_interjection str | None most recent user interjection text (consumed)
|
||||
|
||||
action_queue collections.deque[Tensor] pending action chunks
|
||||
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
|
||||
|
||||
events_this_tick list[str] triggers consumed this tick
|
||||
_tick Tick current tick (set by the loop)
|
||||
|
||||
mode str "action" (run the robot) | "paused"
|
||||
(action loop stopped — robot holds)
|
||||
|
||||
log_lines list[str] human-readable status lines printed each tick
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
|
||||
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
|
||||
"""Build a fresh runtime state dict with sensible defaults."""
|
||||
return {
|
||||
"task": task,
|
||||
"current_plan": None,
|
||||
"current_subtask": None,
|
||||
"current_memory": None,
|
||||
"recent_interjection": None,
|
||||
"action_queue": deque(),
|
||||
"tool_calls_pending": [],
|
||||
"events_this_tick": [],
|
||||
"log_lines": [],
|
||||
"mode": "action",
|
||||
"stop": False,
|
||||
}
|
||||
|
||||
|
||||
def take_event(state: dict[str, Any], event_name: str) -> bool:
|
||||
"""Pop ``event_name`` from ``events_this_tick`` if present.
|
||||
|
||||
Steps that consume an event call this so the same event doesn't
|
||||
re-fire on a sibling step within the same tick.
|
||||
"""
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
if event_name in events:
|
||||
events.remove(event_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def push_log(state: dict[str, Any], line: str) -> None:
|
||||
"""Append ``line`` to the per-tick log buffer; the runtime prints
|
||||
it at the end of the tick."""
|
||||
state.setdefault("log_lines", []).append(line)
|
||||
|
||||
|
||||
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
|
||||
"""Update ``state[key]`` and log a diff line if the value changed.
|
||||
|
||||
Returns ``True`` if the value actually changed.
|
||||
"""
|
||||
prev = state.get(key)
|
||||
if prev == value:
|
||||
return False
|
||||
state[key] = value
|
||||
if label is not None:
|
||||
push_log(state, f" {label}: {value}")
|
||||
return True
|
||||
@@ -1,955 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference steps for the PI052 multi-rate runtime.
|
||||
|
||||
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
|
||||
the runtime applies them in order each tick. When a step's trigger
|
||||
doesn't fire, the step is a no-op and the runtime moves on.
|
||||
|
||||
Stream-to-step mapping mirrors the ``subtasks_vqa.yaml`` recipe:
|
||||
|
||||
* ``LowLevelForward`` — calls ``policy.select_action`` for the
|
||||
action chunk; trained by
|
||||
``low_level_execution``
|
||||
* ``EnqueueChunk`` — pushes the chunk to ``action_queue``
|
||||
* ``DispatchAction`` — pops one action per control tick and
|
||||
forwards to the robot
|
||||
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
|
||||
next subtask; trained by
|
||||
``high_level_subtask``
|
||||
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
|
||||
``memory_update``
|
||||
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
|
||||
``user_interjection_response``
|
||||
* ``AskVQAFwd`` — fires on stdin question; trained by
|
||||
``ask_vqa_*``
|
||||
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
|
||||
the matching ``Tool`` instance
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log, set_if_changed, take_event
|
||||
from .triggers import EventTrigger, HzTrigger, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step base + runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceStep:
|
||||
"""A trigger-gated callable. Subclasses override :meth:`run`."""
|
||||
|
||||
trigger: Trigger
|
||||
|
||||
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.trigger.should_fire(state["_tick"], state):
|
||||
return state
|
||||
return self.run(state) or state
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level (action) path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LowLevelForward(InferenceStep):
|
||||
"""Run the policy's action head and produce one action chunk."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Callable ``() -> dict``: returns the current observation batch
|
||||
(already preprocessed). Typically wraps the robot's camera /
|
||||
proprio reads. ``None`` in dry-run mode → step skips."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or self.observation_provider is None:
|
||||
return None
|
||||
# ``/vlm`` mode pauses the whole action loop so the robot holds
|
||||
# position while the operator probes the VLM with VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
if not state.get("task"):
|
||||
return None
|
||||
|
||||
# PI052 produces *action chunks* (typically 50 steps via
|
||||
# flow-matching). Every step gets dispatched to the robot;
|
||||
# popping one per dispatch tick is essentially free. Only
|
||||
# generate a new chunk once the previous one has fully
|
||||
# drained — this is the canonical "sense → think → act"
|
||||
# loop. Refreshing while a chunk is still queued causes the
|
||||
# new chunk to "telescope" past the old one (planned from an
|
||||
# observation that's already 25+ steps stale by the time it
|
||||
# starts dispatching).
|
||||
queue = state.setdefault("action_queue", [])
|
||||
if len(queue) > 0:
|
||||
return None
|
||||
|
||||
observation = self.observation_provider()
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# The action expert is conditioned on the SUBTASK generated by
|
||||
# the high-level loop (``HighLevelSubtaskFwd`` runs earlier in
|
||||
# the pipeline and writes ``current_subtask``). Matches the
|
||||
# training-time ``low_level_execution`` recipe — ``user(${subtask})``.
|
||||
# Falls back to the task string only on the very first frame,
|
||||
# before the high-level loop has produced a subtask.
|
||||
subtask = state.get("current_subtask") or state.get("task") or ""
|
||||
ctx = [{"role": "user", "content": subtask}]
|
||||
# ``add_generation_prompt=False`` to match the training-time
|
||||
# prefix shape: at training the action expert sees the rendered
|
||||
# user turn ending at ``<|im_end|>`` (no trailing
|
||||
# ``<|im_start|>assistant\n``). Passing True here would append
|
||||
# extra role-marker tokens the action expert never saw during
|
||||
# training.
|
||||
text_batch = _build_text_batch(self.policy, ctx, add_generation_prompt=False)
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
observation = dict(observation)
|
||||
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||
|
||||
try:
|
||||
# ``predict_action_chunk`` returns the *full* chunk shape
|
||||
# ``(batch, n_action_steps, action_dim)``. Enqueue every
|
||||
# step so DispatchAction at ctrl_hz can drain them
|
||||
# smoothly until the next refresh.
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"predict_action_chunk failed: %s",
|
||||
exc,
|
||||
exc_info=logger.isEnabledFor(logging.DEBUG),
|
||||
)
|
||||
push_log(
|
||||
state,
|
||||
f" [warn] predict_action_chunk failed: "
|
||||
f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
return None
|
||||
|
||||
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
|
||||
# each step as a ``(1, action_dim)`` tensor so the existing
|
||||
# action executor's batch-squeeze logic works unchanged.
|
||||
if chunk.ndim == 3:
|
||||
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
|
||||
elif chunk.ndim == 2:
|
||||
chunk_iter = chunk
|
||||
else:
|
||||
chunk_iter = chunk.unsqueeze(0)
|
||||
|
||||
for step in chunk_iter:
|
||||
queue.append(step.unsqueeze(0))
|
||||
state["last_chunk_size"] = int(chunk_iter.shape[0])
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchAction(InferenceStep):
|
||||
"""Pop one action per tick and hand it to the robot.
|
||||
|
||||
In dry-run mode (``robot_executor=None``) the step still pops the
|
||||
queue so it doesn't grow unbounded — the popped tensor is logged
|
||||
instead of executed.
|
||||
|
||||
Wall-clock catch-up: the action queue represents an open-loop
|
||||
trajectory at a fixed step rate (``trigger.hz`` ≈ ``ctrl_hz``).
|
||||
When the main loop stalls — e.g. an LLM call for the high-level
|
||||
subtask blocks for ~2 s on MPS — the dispatch trigger fires only
|
||||
once over that whole interval. Naively popping a single entry per
|
||||
fire makes the robot lag further and further behind the planned
|
||||
timeline, and a 50-step chunk would take ~125 s to drain instead
|
||||
of ~1.7 s. Track real elapsed time between dispatches and pop
|
||||
``round(elapsed * hz)`` entries, sending the most recent one. The
|
||||
skipped intermediate joint targets are stale anyway — the dynamixel
|
||||
will smooth toward the latest goal position.
|
||||
"""
|
||||
|
||||
robot_executor: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
|
||||
_last_dispatch_t: float | None = field(default=None, init=False)
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
# ``/vlm`` mode pauses dispatch — the robot holds its last
|
||||
# commanded position while the operator runs VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
queue = state.get("action_queue")
|
||||
if not queue:
|
||||
# Reset wall-clock anchor when the queue is empty so the
|
||||
# next chunk doesn't see a huge fake "elapsed" window.
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
now = _time.monotonic()
|
||||
hz = getattr(self.trigger, "hz", 30.0)
|
||||
if self._last_dispatch_t is None or hz <= 0:
|
||||
n_to_pop = 1
|
||||
else:
|
||||
elapsed = now - self._last_dispatch_t
|
||||
# ``max(1, ...)`` so we always pop at least one when the
|
||||
# trigger fires; ``min(len(queue), ...)`` so we don't run
|
||||
# off the end of the chunk.
|
||||
n_to_pop = max(1, min(len(queue), int(round(elapsed * hz))))
|
||||
self._last_dispatch_t = now
|
||||
|
||||
# Drain ``n_to_pop`` stale entries, keep only the latest as the
|
||||
# action actually sent. The intermediate joint targets would
|
||||
# all be ~10–30 ms apart in chunk time — the robot can't track
|
||||
# them individually anyway when the host loop is slow.
|
||||
latest = None
|
||||
for _ in range(n_to_pop):
|
||||
if not queue:
|
||||
break
|
||||
latest = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
|
||||
state["actions_dispatched"] = state.get("actions_dispatched", 0) + 1
|
||||
|
||||
if latest is not None and self.robot_executor is not None:
|
||||
self.robot_executor(latest)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level (text) paths — all use policy.select_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_LOC_TOKENIZER_CACHE: dict[str, Any] = {}
|
||||
|
||||
|
||||
def _get_loc_tokenizer(tok_name: str, auto_tokenizer_cls: Any, register_loc_fn: Any) -> Any:
|
||||
"""Return a loc-token-registered tokenizer, loading from disk only once.
|
||||
|
||||
``AutoTokenizer.from_pretrained`` + loc-token registration is expensive and
|
||||
the result is immutable, so cache per ``tok_name``.
|
||||
"""
|
||||
tokenizer = _LOC_TOKENIZER_CACHE.get(tok_name)
|
||||
if tokenizer is None:
|
||||
tokenizer = register_loc_fn(auto_tokenizer_cls.from_pretrained(tok_name))
|
||||
_LOC_TOKENIZER_CACHE[tok_name] = tokenizer
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _build_text_batch(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Tokenize chat messages into the batch ``select_message`` expects.
|
||||
|
||||
PI052's backbone (PaliGemma) ships no chat template, so we train on
|
||||
a plain role-prefixed concatenation built by
|
||||
``PI052TextTokenizerStep``. We reuse that exact formatter so the
|
||||
inference prefix matches training; ``add_generation_prompt`` appends
|
||||
the bare ``Assistant: `` header the LM head continues from.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
|
||||
_flatten_say_tool_calls,
|
||||
_format_messages,
|
||||
_strip_blocks,
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
tok_name = (
|
||||
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
)
|
||||
# Register PaliGemma's <locDDDD> tokens so inference encoding /
|
||||
# decoding sees them as single vocab ids — must match training.
|
||||
# The tokenizer is read-only after registration, so cache it: rebuilding it
|
||||
# from disk on every call dominated eval runtime (this runs twice per env
|
||||
# per replan — subtask gen + action prompt).
|
||||
tokenizer = _get_loc_tokenizer(tok_name, AutoTokenizer, register_paligemma_loc_tokens)
|
||||
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
|
||||
prompt, _spans = _format_messages(messages)
|
||||
if add_generation_prompt:
|
||||
prompt = prompt + "Assistant: "
|
||||
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
ids = encoded["input_ids"]
|
||||
attn = encoded.get("attention_mask")
|
||||
if attn is None and tokenizer.pad_token_id is not None:
|
||||
attn = ids != tokenizer.pad_token_id
|
||||
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
|
||||
attn = attn.bool()
|
||||
|
||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||
# model), which the caller's broad except would swallow silently.
|
||||
device = getattr(getattr(policy, "config", None), "device", None)
|
||||
if device is not None:
|
||||
try:
|
||||
ids = ids.to(device)
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
||||
new = dict(m)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
|
||||
|
||||
|
||||
@dataclass
|
||||
class HighLevelSubtaskFwd(InferenceStep):
|
||||
"""At ~1 Hz, ask the policy for the next subtask.
|
||||
|
||||
Mirrors the ``high_level_subtask`` recipe layout exactly:
|
||||
|
||||
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||
user: "Current subtask: ${subtask}" (if subtask present)
|
||||
↓ generate ↓
|
||||
assistant: <next subtask>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Same shape as ``LowLevelForward.observation_provider``. When
|
||||
set, the resulting observation is merged into ``select_message``'s
|
||||
batch so text generation runs against real video + state."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not state.get("task"):
|
||||
return None
|
||||
# ``/vlm`` mode pauses subtask generation along with the rest of
|
||||
# the action loop.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
# Gate to chunk boundaries: only generate a fresh subtask when
|
||||
# the action queue is empty (i.e. right before LowLevelForward
|
||||
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
|
||||
# and running it every loop iteration starves DispatchAction
|
||||
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
|
||||
# of 30/sec and the robot barely moves. Tying it to the same
|
||||
# "queue empty" condition as the chunk refresh produces a
|
||||
# clean sense → think → act cycle.
|
||||
#
|
||||
# Rearm the trigger when skipping so a low-hz schedule
|
||||
# (e.g. ``--high_level_hz=0.2`` = once per 5 s) doesn't lose
|
||||
# the slot: the trigger fires once on the timer but the brief
|
||||
# queue-empty window almost never coincides, so without rearm
|
||||
# HL would effectively never run.
|
||||
queue = state.get("action_queue") or []
|
||||
if len(queue) > 0:
|
||||
if hasattr(self.trigger, "rearm"):
|
||||
self.trigger.rearm()
|
||||
return None
|
||||
# Per-chunk-boundary throttle: at each "queue empty" moment we
|
||||
# increment a counter; subtask gen only fires once the counter
|
||||
# reaches ``subtask_chunks_per_gen``. Lets the operator run e.g.
|
||||
# 5 action chunks per subtask-gen so the LM head doesn't churn
|
||||
# every 1.7 s (a fresh subtask while the previous one is still
|
||||
# being executed is wasted compute *and* causes the action
|
||||
# expert's flow trajectory to be re-planned mid-grasp).
|
||||
chunks_per_gen = max(1, int(state.get("subtask_chunks_per_gen", 1) or 1))
|
||||
# Initialise so the first chunk boundary fires immediately
|
||||
# (counter starts at chunks_per_gen, decrements per skip,
|
||||
# generates and resets when it hits 0).
|
||||
if "_hl_chunks_until_gen" not in state:
|
||||
state["_hl_chunks_until_gen"] = 0
|
||||
if state["_hl_chunks_until_gen"] > 0:
|
||||
state["_hl_chunks_until_gen"] -= 1
|
||||
if hasattr(self.trigger, "rearm"):
|
||||
self.trigger.rearm()
|
||||
return None
|
||||
state["_hl_chunks_until_gen"] = chunks_per_gen - 1
|
||||
ctx = _msgs_for_subtask(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
# Default: greedy argmax, no min_new_tokens, no special-token
|
||||
# suppression — matches training. Operator can override via
|
||||
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
|
||||
# on the CLI; useful for under-trained checkpoints whose LM
|
||||
# head still favours EOS at position 0 (pre-trained chat
|
||||
# backbone's short-turn prior hasn't been fully overridden
|
||||
# by the fine-tuning supervision yet).
|
||||
msg = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="subtask gen",
|
||||
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
|
||||
temperature=float(state.get("text_gen_temperature") or 0.0),
|
||||
top_p=float(state.get("text_gen_top_p") or 1.0),
|
||||
# Subtasks never legitimately contain PaliGemma ``<loc>``
|
||||
# tokens — suppress them so a checkpoint whose LM head
|
||||
# has drifted toward the pretrained loc-prior falls back
|
||||
# to its (still-correct) text mass.
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
# Diagnostics: surface what the model is *actually* producing
|
||||
# at chunk boundaries, even when the output gets rejected or
|
||||
# repeats. Memorisation collapse looks like "same accepted
|
||||
# subtask N times in a row" or "gibberish_count rising while
|
||||
# current_subtask is stuck". The state panel renders these.
|
||||
state["last_subtask_raw"] = msg or ""
|
||||
# Persistent empty completion is its own failure mode (model
|
||||
# immediately EOS-es from the chat-template generation
|
||||
# prompt) — surface it once every N occurrences so the
|
||||
# operator can distinguish "generation failing silently"
|
||||
# from "generating fine but filter rejecting".
|
||||
if not msg:
|
||||
empties = state.get("subtask_empty_count", 0) + 1
|
||||
state["subtask_empty_count"] = empties
|
||||
if empties == 1 or empties % 5 == 0:
|
||||
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
|
||||
if debug:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen empty (×{empties}); {debug}",
|
||||
)
|
||||
else:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen returned empty (×{empties}) — "
|
||||
"no tokens generated (head EOS-ing before any "
|
||||
"non-special token).",
|
||||
)
|
||||
if msg and _looks_like_gibberish(msg):
|
||||
# Bump a counter so the operator can see the model is
|
||||
# struggling without spamming the log every tick. A first
|
||||
# rejection still logs once so the failure is visible.
|
||||
count = state.get("subtask_gibberish_count", 0) + 1
|
||||
state["subtask_gibberish_count"] = count
|
||||
if count == 1 or count % 30 == 0:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if msg:
|
||||
prev_subtask = state.get("current_subtask")
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
# Stash the just-completed subtask so ``MemoryUpdateFwd``
|
||||
# can drop it into its prompt as ``Completed subtask:``
|
||||
# — the recipe binds ``completed_subtask`` to
|
||||
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
|
||||
# that was active *before* the change.
|
||||
if prev_subtask:
|
||||
state["prior_subtask"] = prev_subtask
|
||||
# Subtask change is a downstream trigger.
|
||||
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||
state["subtask_repeat_count"] = 0
|
||||
else:
|
||||
# Same accepted string regenerated — memorisation tell.
|
||||
# Once this counter climbs past a few, you're seeing
|
||||
# the model unable to move past the current subtask
|
||||
# despite the chunk having drained (visual scene may
|
||||
# have changed but the LM is replaying training
|
||||
# tokens).
|
||||
state["subtask_repeat_count"] = (
|
||||
state.get("subtask_repeat_count", 0) + 1
|
||||
)
|
||||
# Silently skip empty completions — common when the model
|
||||
# warms up or generates only EOS; logging it every tick at
|
||||
# ctrl_hz is just noise.
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryUpdateFwd(InferenceStep):
|
||||
"""On subtask boundary, refresh the compressed memory.
|
||||
|
||||
Mirrors the ``memory_update`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if prior memory)
|
||||
user: "Completed subtask: ${completed_subtask}" (if subtask)
|
||||
↓ generate ↓
|
||||
assistant: <new memory>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Don't consume the event — multiple steps may want to react.
|
||||
if self.policy is None:
|
||||
return None
|
||||
ctx = _msgs_for_memory(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="memory gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
state["last_memory_raw"] = new_memory or ""
|
||||
if new_memory and _looks_like_gibberish(new_memory):
|
||||
count = state.get("memory_gibberish_count", 0) + 1
|
||||
state["memory_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] memory gen rejected (gibberish ×{count}): {new_memory[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if new_memory:
|
||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInterjectionFwd(InferenceStep):
|
||||
"""On stdin interjection, refresh the plan + emit a paired ``say``.
|
||||
|
||||
Mirrors the ``user_interjection_response`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
|
||||
user: "${interjection}" (the new utterance)
|
||||
↓ generate ↓
|
||||
assistant: <plan + <say>...</say>>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_interjection"):
|
||||
return None
|
||||
ctx = _msgs_for_interjection(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="plan/say gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
if not out:
|
||||
# Don't log every empty completion — happens repeatedly on
|
||||
# MPS during warm-up and floods the panel. The user can
|
||||
# re-trigger by typing again.
|
||||
return None
|
||||
if _looks_like_gibberish(out):
|
||||
count = state.get("plan_gibberish_count", 0) + 1
|
||||
state["plan_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] plan/say gen rejected (gibberish ×{count}): {out[:60]!r}",
|
||||
)
|
||||
return None
|
||||
# Heuristic split: model is trained to emit one assistant turn
|
||||
# carrying both plan text AND a `say` tool call. Look for a
|
||||
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
||||
# text → plan, no speech.
|
||||
plan_text, speech_text = _split_plan_and_say(out)
|
||||
if plan_text and _looks_like_gibberish(plan_text):
|
||||
plan_text = ""
|
||||
if plan_text:
|
||||
set_if_changed(state, "current_plan", plan_text, label="plan")
|
||||
if speech_text:
|
||||
push_log(state, f" speech: {speech_text}")
|
||||
state.setdefault("tool_calls_pending", []).append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "say", "arguments": {"text": speech_text}},
|
||||
}
|
||||
)
|
||||
state.setdefault("events_this_tick", []).append("tool_call_pending")
|
||||
# Mark interjection consumed.
|
||||
state["recent_interjection"] = None
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskVQAFwd(InferenceStep):
|
||||
"""On stdin question, answer a frame-grounded VQA.
|
||||
|
||||
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
|
||||
turn carrying just the VQA question, plus the camera image block
|
||||
in training (we drop the image at inference because the dataset's
|
||||
image preprocessing doesn't match SmolVLM's vision tower input).
|
||||
|
||||
user: <question>
|
||||
↓ generate ↓
|
||||
assistant: <vqa answer>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_vqa_query"):
|
||||
return None
|
||||
question = state.get("recent_vqa_query")
|
||||
if not question:
|
||||
return None
|
||||
ctx = _msgs_for_vqa(question)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
)
|
||||
# VQA answers are intentionally JSON-like during training, so
|
||||
# ``_looks_like_gibberish`` would false-positive on them. Keep
|
||||
# the answer as-is — the VQA panel line lets the user judge.
|
||||
if answer:
|
||||
push_log(state, f" vqa: {answer}")
|
||||
state["recent_vqa_query"] = None
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchToolCalls(InferenceStep):
|
||||
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
|
||||
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
take_event(state, "tool_call_pending")
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
for call in pending:
|
||||
try:
|
||||
fn = (call or {}).get("function") or {}
|
||||
name = fn.get("name")
|
||||
args = fn.get("arguments") or {}
|
||||
tool = self.tools.get(name)
|
||||
if tool is None:
|
||||
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
|
||||
continue
|
||||
tool.call(args)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
push_log(state, f" [error] tool dispatch failed: {exc}")
|
||||
state["tool_calls_pending"] = []
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _looks_like_gibberish(text: str) -> bool:
|
||||
"""Heuristically detect generation that's clearly off the rails.
|
||||
|
||||
Memorised models can collapse to dominant-mode outputs when the
|
||||
prompt drifts even slightly from training distribution. Reject:
|
||||
|
||||
* empty / whitespace-only
|
||||
* too few alphabetic characters (mostly punctuation)
|
||||
* a single character repeated past the threshold
|
||||
* starts with ``":"`` and contains no letters
|
||||
* too few unique tokens — e.g. ``"the"``, ``"the the the"``,
|
||||
``"Ass\\n::\\nthe"`` (the collapse seen on real-robot frames
|
||||
where the model emits one or two memorised tokens repeatedly)
|
||||
* chat-template fragment leakage (``Assistant:``, ``User:``,
|
||||
``Ass\\n``)
|
||||
|
||||
Real subtasks look like ``"close the gripper to grasp the blue
|
||||
cube"`` — multiple unique alphabetic tokens, no role-marker
|
||||
fragments. Anything materially shorter than that is rejected.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
stripped = text.strip()
|
||||
alpha = sum(1 for c in stripped if c.isalpha())
|
||||
if alpha < max(3, len(stripped) // 8):
|
||||
return True
|
||||
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
|
||||
return True
|
||||
# Single repeating char: e.g. ``""""""``.
|
||||
if len(set(stripped)) <= 2 and len(stripped) > 4:
|
||||
return True
|
||||
# Chat-template fragment leakage — the model emits ``Ass``,
|
||||
# ``Assistant:``, ``User:``, often with extra newlines/colons.
|
||||
# Reject if the cleaned text is mostly role-marker shards.
|
||||
cleaned = stripped.replace("\n", " ").replace(":", " ")
|
||||
for marker in ("Assistant", "User", "Ass "):
|
||||
if marker in cleaned and len(cleaned.split()) < 4:
|
||||
return True
|
||||
tokens = [t for t in cleaned.split() if any(c.isalpha() for c in t)]
|
||||
unique_alpha = {t.lower() for t in tokens}
|
||||
# Short degenerate output — model stuck on ``the`` or a couple of
|
||||
# memorised single-token continuations.
|
||||
if len(unique_alpha) < 3 and len(stripped) < 80:
|
||||
return True
|
||||
# Long repetition collapse — the LM head loops an n-gram for the
|
||||
# whole generation budget ("the arm the arm … the the the the").
|
||||
# Length-independent: many tokens but a tiny unique ratio. The
|
||||
# earlier ``< 80`` check missed these because the looped string
|
||||
# blows well past 80 chars.
|
||||
if len(tokens) >= 8 and len(unique_alpha) <= max(3, len(tokens) // 10):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _control_context_messages(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
include_completed: bool = False,
|
||||
extra_user: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a chat-template-ready prompt from current runtime state.
|
||||
|
||||
Mirrors what ``subtasks_vqa.yaml`` renders into ``${task}\nPlan:
|
||||
${plan}\nMemory: ${memory}`` for the high-level branches.
|
||||
"""
|
||||
# Always emit ``Plan: `` / ``Memory: `` labels — even with empty
|
||||
# values — to mirror the training-time recipe substitution.
|
||||
task = state.get("task") or ""
|
||||
plan = state.get("current_plan") or ""
|
||||
memory = state.get("current_memory") or ""
|
||||
parts = [task, f"Plan: {plan}", f"Memory: {memory}"]
|
||||
if include_completed and state.get("current_subtask"):
|
||||
parts.append(f"Completed subtask: {state['current_subtask']}")
|
||||
head = "\n".join(parts)
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
|
||||
if extra_user:
|
||||
msgs.append({"role": "user", "content": extra_user})
|
||||
return msgs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
|
||||
# message layout in ``subtasks_vqa.yaml`` so the chat-templated
|
||||
# prompt at inference matches what the model saw during training.
|
||||
# Generic ``_control_context_messages`` is kept around as a fallback
|
||||
# for ad-hoc callers but the four high-level steps now use these.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hirobot_user_head(state: dict[str, Any]) -> str:
|
||||
"""Build the ``task\\nPlan: …\\nMemory: …`` user content string.
|
||||
|
||||
Mirrors what the recipe renders at training time, where
|
||||
``language_render._substitute`` substitutes empty strings for
|
||||
missing ``${plan}`` / ``${memory}`` bindings — i.e. the
|
||||
``Plan: `` / ``Memory: `` prefix labels are *always* in the
|
||||
user turn, even when their values aren't set yet. Skipping them
|
||||
here (the previous behaviour) produced a different prompt shape
|
||||
on early frames before plan / memory are populated and on
|
||||
samples where the dataset has no plan / memory annotation.
|
||||
"""
|
||||
task = state.get("task") or ""
|
||||
plan = state.get("current_plan") or ""
|
||||
memory = state.get("current_memory") or ""
|
||||
return f"{task}\nPlan: {plan}\nMemory: {memory}"
|
||||
|
||||
|
||||
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``high_level_subtask`` recipe layout — predict the subtask from the
|
||||
task. The v-current recipe's user turn is just ``${task}`` (plan and
|
||||
memory are not trained), so the inference prompt is the bare task —
|
||||
no ``Plan: `` / ``Memory: `` lines.
|
||||
"""
|
||||
return [{"role": "user", "content": state.get("task") or ""}]
|
||||
|
||||
|
||||
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
|
||||
|
||||
Recipe layout (``subtask_mem.yaml``):
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if_present prior)
|
||||
user: "Completed subtask: ${completed}" (if_present completed)
|
||||
assistant: → predicts new memory
|
||||
|
||||
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
|
||||
``state['current_memory']`` is the memory the policy last emitted
|
||||
(= the ``prior_memory`` binding at training), and
|
||||
``state['prior_subtask']`` is the subtask that just got replaced
|
||||
(= the ``completed_subtask`` binding at training).
|
||||
"""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""},
|
||||
]
|
||||
prior_memory = state.get("current_memory")
|
||||
if prior_memory:
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
|
||||
)
|
||||
completed_subtask = state.get("prior_subtask")
|
||||
if completed_subtask:
|
||||
msgs.append(
|
||||
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
|
||||
)
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``user_interjection_response`` recipe layout."""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""}
|
||||
]
|
||||
if state.get("current_plan"):
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
|
||||
)
|
||||
interjection = state.get("recent_interjection")
|
||||
if interjection:
|
||||
msgs.append({"role": "user", "content": interjection})
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_plan(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``plan_generation`` recipe layout — bare task → plan.
|
||||
|
||||
The assistant turn is the generation target, so we only render
|
||||
the user turn at inference; the runtime appends the predicted
|
||||
plan after sampling.
|
||||
"""
|
||||
return [{"role": "user", "content": state.get("task") or ""}]
|
||||
|
||||
|
||||
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
|
||||
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
|
||||
return [{"role": "user", "content": question}]
|
||||
|
||||
|
||||
def _maybe_observation(provider: Any) -> dict | None:
|
||||
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||
|
||||
Errors from the provider are logged at debug level and swallowed —
|
||||
text generation still runs (in text-only mode) so a flaky frame
|
||||
source doesn't kill the REPL.
|
||||
"""
|
||||
if provider is None:
|
||||
return None
|
||||
try:
|
||||
return provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s — falling back to text-only", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_with_policy(
|
||||
policy: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
observation: dict | None = None,
|
||||
state: dict[str, Any] | None = None,
|
||||
label: str = "select_message",
|
||||
min_new_tokens: int = 0,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
suppress_loc_tokens: bool = False,
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
When ``observation`` carries ``observation.images.*`` and
|
||||
``observation.state``, those are merged into the batch so
|
||||
``select_message`` runs the same VLM prefix the policy was trained
|
||||
on. Without an observation the runtime falls back to a text-only
|
||||
prompt — the text head still runs, but generations may drift from
|
||||
the training distribution.
|
||||
|
||||
Failures are surfaced both to the module logger (``warning``) and,
|
||||
when ``state`` is given, to the runtime's user-visible log via
|
||||
:func:`push_log`, so the REPL no longer "looks dead" when
|
||||
something goes wrong inside generation.
|
||||
"""
|
||||
if not hasattr(policy, "select_message"):
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] policy has no select_message — skipping {label}")
|
||||
return ""
|
||||
text_batch = _build_text_batch(policy, messages)
|
||||
try:
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
batch: dict[str, Any] = {
|
||||
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||
}
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||
batch[k] = v
|
||||
kwargs: dict[str, Any] = {
|
||||
"tokenizer": text_batch["tokenizer"],
|
||||
"min_new_tokens": min_new_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
|
||||
return policy.select_message(batch, **kwargs)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
|
||||
return ""
|
||||
|
||||
|
||||
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
def _split_plan_and_say(text: str) -> tuple[str, str]:
|
||||
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
|
||||
|
||||
The training-time tool-call serializer wraps ``say(text="…")`` in a
|
||||
deterministic textual marker so prefix-LM-style training learns to
|
||||
emit it. The runtime parses it back here. If no marker is present,
|
||||
the entire text is treated as plan with no speech.
|
||||
"""
|
||||
if not text:
|
||||
return "", ""
|
||||
match = _SAY_RE.search(text)
|
||||
if not match:
|
||||
return text.strip(), ""
|
||||
speech = match.group(1).strip().strip('"').strip("'")
|
||||
plan = (text[: match.start()] + text[match.end() :]).strip()
|
||||
return plan, speech
|
||||
@@ -1,134 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trigger primitives for PI052's multi-rate inference runtime.
|
||||
|
||||
Mirrors the plan's Section "Runtime orchestration": each
|
||||
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
|
||||
whether the step fires. Two trigger flavours cover all the cadences
|
||||
the canonical recipe needs:
|
||||
|
||||
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
|
||||
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
|
||||
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
|
||||
memory update; user interjection → plan refresh; user VQA query →
|
||||
vqa answer; pending tool call → dispatcher)
|
||||
|
||||
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
|
||||
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
|
||||
every step shares a single time source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tick:
|
||||
"""Single tick from :class:`TickClock`. Carries time references the
|
||||
runtime steps consume to gate themselves."""
|
||||
|
||||
index: int
|
||||
"""Monotonic counter — increments by one per tick."""
|
||||
|
||||
monotonic_seconds: float
|
||||
"""``time.monotonic()`` at the start of this tick."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickClock:
|
||||
"""Drives the runtime loop at up to ``max_rate_hz``.
|
||||
|
||||
Sleeps just enough between :meth:`advance` calls to enforce the
|
||||
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
|
||||
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
|
||||
"""
|
||||
|
||||
max_rate_hz: float = 50.0
|
||||
_index: int = field(default=0, init=False)
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def advance(self) -> Tick:
|
||||
period = 1.0 / max(self.max_rate_hz, 0.1)
|
||||
now = time.monotonic()
|
||||
if self._last_seconds is not None:
|
||||
sleep_for = (self._last_seconds + period) - now
|
||||
if sleep_for > 0:
|
||||
time.sleep(sleep_for)
|
||||
now = time.monotonic()
|
||||
self._last_seconds = now
|
||||
self._index += 1
|
||||
return Tick(index=self._index, monotonic_seconds=now)
|
||||
|
||||
|
||||
class Trigger(Protocol):
|
||||
"""Decide whether the next ``InferenceStep`` should fire."""
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class HzTrigger:
|
||||
"""Fire at most ``hz`` times per second.
|
||||
|
||||
A step that gates further (e.g. ``HighLevelSubtaskFwd`` skipping
|
||||
when the action queue is non-empty) and wants the trigger to
|
||||
retry next tick instead of waiting a full period can call
|
||||
:meth:`rearm` from inside ``run``. Without this, a low-hz trigger
|
||||
(e.g. ``hz=0.2`` = once per 5 s) almost never coincides with the
|
||||
brief queue-empty window and the step never fires at all.
|
||||
"""
|
||||
|
||||
hz: float
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
period = 1.0 / max(self.hz, 1e-6)
|
||||
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
|
||||
self._last_seconds = tick.monotonic_seconds
|
||||
return True
|
||||
return False
|
||||
|
||||
def rearm(self) -> None:
|
||||
"""Mark the trigger as not having fired, so the next tick re-evaluates.
|
||||
|
||||
Used by a step that decided to skip after ``should_fire`` already
|
||||
committed the firing — keeps the cadence honest without losing
|
||||
the slot.
|
||||
"""
|
||||
self._last_seconds = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventTrigger:
|
||||
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
|
||||
|
||||
The runtime fills ``events_this_tick`` once per tick from:
|
||||
|
||||
* stdin / network input (``user_interjection``, ``user_vqa_query``,
|
||||
``stop``)
|
||||
* internal state transitions (``subtask_change``,
|
||||
``tool_call_pending``)
|
||||
|
||||
The list is consumed (cleared at the end of the tick) so events
|
||||
fire at most once.
|
||||
"""
|
||||
|
||||
event_name: str
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
return self.event_name in events
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rich-based REPL layout for the PI052 runtime.
|
||||
|
||||
Two-zone terminal layout:
|
||||
|
||||
[chat scrollback — user messages / robot responses, scrolls naturally]
|
||||
|
||||
┌── State ──────────────────────────────────────────┐
|
||||
│ task please clean up the kitchen │
|
||||
│ subtask grasp the handle of the sponge │
|
||||
│ plan 1. grasp sponge 2. wipe 3. tidy │
|
||||
│ memory sponge picked up; counter still dirty │
|
||||
└───────────────────────────────────────────────────┘
|
||||
> _
|
||||
|
||||
The state panel re-renders on every state change. Chat lines are
|
||||
``console.print``'d above the live region so they accumulate naturally
|
||||
in scrollback. Implemented with :class:`rich.live.Live` plus
|
||||
:func:`rich.console.Console.input` for the prompt — when an input is
|
||||
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
|
||||
panel for cursor position.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try: # rich is optional; only required for the interactive REPL.
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
_HAS_RICH = True
|
||||
except ImportError: # pragma: no cover
|
||||
_HAS_RICH = False
|
||||
Console = Any # type: ignore[assignment]
|
||||
Panel = Any # type: ignore[assignment]
|
||||
Table = Any # type: ignore[assignment]
|
||||
Text = Any # type: ignore[assignment]
|
||||
|
||||
|
||||
_STATE_KEYS = (
|
||||
("task", "task"),
|
||||
("current_subtask", "subtask"),
|
||||
("current_plan", "plan"),
|
||||
("current_memory", "memory"),
|
||||
)
|
||||
|
||||
|
||||
def make_state_panel(state: dict[str, Any]) -> Any:
|
||||
"""Render the persistent state panel for the live region.
|
||||
|
||||
Returns a :class:`rich.panel.Panel`. Caller passes it to
|
||||
``Live.update(panel)`` whenever the state changes.
|
||||
"""
|
||||
if not _HAS_RICH:
|
||||
raise RuntimeError(
|
||||
"rich is required for the interactive REPL. "
|
||||
"`pip install rich` (it's a transitive dep of lerobot)."
|
||||
)
|
||||
table = Table.grid(padding=(0, 2), expand=True)
|
||||
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
|
||||
table.add_column(justify="left")
|
||||
for key, label in _STATE_KEYS:
|
||||
value = state.get(key)
|
||||
if value is None:
|
||||
rendered = Text("(not set)", style="dim italic")
|
||||
else:
|
||||
rendered = Text(str(value), style="bold")
|
||||
table.add_row(label, rendered)
|
||||
queue = state.get("action_queue")
|
||||
queue_len = len(queue) if hasattr(queue, "__len__") else 0
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
footer = Text.assemble(
|
||||
("queued actions: ", "dim"),
|
||||
(str(queue_len), "bold cyan"),
|
||||
(" pending tool calls: ", "dim"),
|
||||
(str(len(pending)), "bold magenta"),
|
||||
)
|
||||
table.add_row("", footer)
|
||||
run_mode = state.get("mode", "action")
|
||||
mode_tag = (
|
||||
"[green]action[/]" if run_mode == "action" else "[yellow]paused[/]"
|
||||
)
|
||||
return Panel(
|
||||
table,
|
||||
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
||||
|
||||
def print_user_line(console: Any, line: str) -> None:
|
||||
"""Append a user-typed line to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
print(f"you: {line}", flush=True)
|
||||
return
|
||||
console.print(f"[bold cyan]you:[/] {line}")
|
||||
|
||||
|
||||
def print_robot_lines(console: Any, lines: list[str]) -> None:
|
||||
"""Append robot/runtime log lines to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
for line in lines:
|
||||
print(f"robot: {line.lstrip()}", flush=True)
|
||||
return
|
||||
for line in lines:
|
||||
# The runtime uses leading whitespace + "label: text"; render
|
||||
# the label in green and the value in default for readability.
|
||||
stripped = line.lstrip()
|
||||
if ":" in stripped:
|
||||
label, _, value = stripped.partition(":")
|
||||
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
|
||||
else:
|
||||
console.print(f"[bold green]robot:[/] {stripped}")
|
||||
@@ -1,423 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Interactive VQA for the PI052 runtime.
|
||||
|
||||
In ``/vlm`` mode a typed line is treated as a VQA question. This module
|
||||
runs the full interactive flow:
|
||||
|
||||
1. pull the current observation and list available cameras,
|
||||
2. ask the operator which camera to ground the question on,
|
||||
3. generate the answer with the VLM conditioned on that one camera,
|
||||
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
|
||||
point (``keypoint``), draw the overlay on the camera frame, save a
|
||||
PNG to ``./vqa_overlays/`` and auto-open it.
|
||||
|
||||
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
|
||||
(see ``lerobot.annotations.steerable_pipeline.validator``):
|
||||
|
||||
* ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}, ...]}``
|
||||
* ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}``
|
||||
* ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IMAGE_PREFIX = "observation.images."
|
||||
|
||||
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
|
||||
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
|
||||
# normalized to the image axis) instead of pixel-coordinate JSON, so the
|
||||
# answer string the runtime parses can be e.g.
|
||||
# ``<loc0512><loc0301> blue cube`` (point) or
|
||||
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
|
||||
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
|
||||
|
||||
# Iteration order for shape matching — most specific keys first so an
|
||||
# answer is classified deterministically.
|
||||
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
_BBOX_COLOR = (255, 64, 64)
|
||||
_POINT_COLOR = (64, 220, 64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def available_cameras(observation: dict | None) -> list[str]:
|
||||
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
|
||||
if not observation:
|
||||
return []
|
||||
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
|
||||
|
||||
|
||||
def camera_short_name(camera_key: str) -> str:
|
||||
"""Strip the ``observation.images.`` prefix for display."""
|
||||
return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key
|
||||
|
||||
|
||||
def prompt_camera_choice(
|
||||
cameras: list[str],
|
||||
*,
|
||||
input_fn: Any = input,
|
||||
print_fn: Any = print,
|
||||
) -> str | None:
|
||||
"""Ask the operator which camera frame to draw a VQA overlay on.
|
||||
|
||||
Accepts either the menu number or the (short or full) camera name.
|
||||
A single-camera setup auto-selects without prompting. Returns the
|
||||
chosen ``observation.images.*`` key, or ``None`` if the operator
|
||||
cancels / gives an invalid answer.
|
||||
"""
|
||||
if not cameras:
|
||||
return None
|
||||
if len(cameras) == 1:
|
||||
return cameras[0]
|
||||
print_fn("Draw the result on which camera?")
|
||||
for i, cam in enumerate(cameras, 1):
|
||||
print_fn(f" [{i}] {camera_short_name(cam)}")
|
||||
try:
|
||||
raw = str(input_fn("camera> ")).strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return None
|
||||
if not raw:
|
||||
return cameras[0]
|
||||
if raw.isdigit():
|
||||
idx = int(raw) - 1
|
||||
return cameras[idx] if 0 <= idx < len(cameras) else None
|
||||
for cam in cameras:
|
||||
if raw == cam or raw == camera_short_name(cam):
|
||||
return cam
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _loc_to_norm(idx: int) -> float:
|
||||
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
|
||||
return max(0.0, min(1023.0, float(idx))) / 1023.0
|
||||
|
||||
|
||||
def parse_loc_answer(answer: str) -> dict | None:
|
||||
"""Parse a PaliGemma ``<loc>``-format spatial VQA answer.
|
||||
|
||||
PI052 trains spatial answers in PaliGemma's native detection
|
||||
vocabulary, label-first: a point is ``<label> <locY><locX>``, a box
|
||||
is ``<label> <locY0><locX0><locY1><locX1>``, and multiple boxes are
|
||||
joined by `` ; `` (e.g. ``cube <loc..><loc..><loc..><loc..> ; box
|
||||
<loc..><loc..><loc..><loc..>``). Loc-first formats are also accepted
|
||||
— this parser strips loc tokens and treats the remainder as the
|
||||
label, so order is irrelevant. Coordinates come back *normalized*
|
||||
([0, 1]); the overlay denormalizes them against the chosen camera
|
||||
frame's pixel size.
|
||||
|
||||
Returns ``{"kind", "payload", "normalized": True}`` on success
|
||||
(``payload`` mirrors the JSON shapes so the overlay code is shared),
|
||||
or ``None`` when the answer carries no ``<loc>`` tokens.
|
||||
"""
|
||||
if not answer or "<loc" not in answer:
|
||||
return None
|
||||
segments = [seg for seg in answer.split(";") if "<loc" in seg]
|
||||
points: list[tuple[float, float, str]] = []
|
||||
boxes: list[tuple[float, float, float, float, str]] = []
|
||||
for seg in segments:
|
||||
locs = [int(m) for m in _LOC_RE.findall(seg)]
|
||||
label = _LOC_RE.sub("", seg).strip()
|
||||
if len(locs) == 2:
|
||||
y, x = (_loc_to_norm(v) for v in locs[:2])
|
||||
points.append((x, y, label))
|
||||
elif len(locs) >= 4:
|
||||
y1, x1, y2, x2 = (_loc_to_norm(v) for v in locs[:4])
|
||||
boxes.append((x1, y1, x2, y2, label))
|
||||
if boxes:
|
||||
detections = [
|
||||
{"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
|
||||
for (x1, y1, x2, y2, lbl) in boxes
|
||||
]
|
||||
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
|
||||
if len(points) == 1:
|
||||
x, y, lbl = points[0]
|
||||
return {
|
||||
"kind": "keypoint",
|
||||
"payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
|
||||
"normalized": True,
|
||||
}
|
||||
if points: # several bare points → treat as detections-as-points
|
||||
detections = [
|
||||
{"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
|
||||
]
|
||||
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
|
||||
return None
|
||||
|
||||
|
||||
def parse_vqa_answer(answer: str) -> dict | None:
|
||||
"""Parse a VQA answer string into ``{"kind", "payload"}``.
|
||||
|
||||
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
|
||||
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
|
||||
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
|
||||
spatial answers are detected first (PI052 trains them in that native
|
||||
format). Returns ``None`` when the answer is neither ``<loc>`` text
|
||||
nor a parseable JSON object.
|
||||
"""
|
||||
if not answer or not answer.strip():
|
||||
return None
|
||||
loc_parsed = parse_loc_answer(answer)
|
||||
if loc_parsed is not None:
|
||||
return loc_parsed
|
||||
try:
|
||||
payload = json.loads(answer)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
|
||||
VQA_ANSWER_SHAPES,
|
||||
)
|
||||
|
||||
shapes = VQA_ANSWER_SHAPES
|
||||
except ImportError: # pragma: no cover - annotation extra not installed
|
||||
shapes = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
keys = set(payload)
|
||||
for kind in _SHAPE_ORDER:
|
||||
required = shapes.get(kind)
|
||||
if required and required <= keys:
|
||||
return {"kind": kind, "payload": payload}
|
||||
return {"kind": "unknown", "payload": payload}
|
||||
|
||||
|
||||
def answer_has_overlay(parsed: dict | None) -> bool:
|
||||
"""True iff ``parsed`` carries drawable spatial coordinates."""
|
||||
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def observation_image_to_pil(image_tensor: Any) -> Any:
|
||||
"""Convert an ``observation.images.*`` tensor to a PIL RGB image.
|
||||
|
||||
The runtime observation stores images as ``(1, C, H, W)`` (or
|
||||
``(C, H, W)``) float tensors in ``[0, 1]``. Reuses
|
||||
``image_array_to_pil_image`` which handles the CHW→HWC transpose and
|
||||
the float→uint8 scaling.
|
||||
"""
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415
|
||||
|
||||
arr = image_tensor
|
||||
if hasattr(arr, "detach"):
|
||||
arr = arr.detach().cpu()
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.numpy()
|
||||
while arr.ndim > 3: # drop leading batch dim(s)
|
||||
arr = arr[0]
|
||||
return image_array_to_pil_image(arr).convert("RGB")
|
||||
|
||||
|
||||
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
|
||||
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
|
||||
|
||||
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
|
||||
``unknown``) are returned as an unmodified copy. When ``parsed`` has
|
||||
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
|
||||
coordinates are scaled to the image's pixel size.
|
||||
"""
|
||||
from PIL import ImageDraw # noqa: PLC0415
|
||||
|
||||
img = image.convert("RGB").copy()
|
||||
kind = parsed.get("kind")
|
||||
payload = parsed.get("payload") or {}
|
||||
draw = ImageDraw.Draw(img)
|
||||
w, h = img.size
|
||||
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
|
||||
|
||||
if kind == "bbox":
|
||||
for det in payload.get("detections") or []:
|
||||
if not isinstance(det, dict):
|
||||
continue
|
||||
box = det.get("bbox")
|
||||
if not (isinstance(box, list | tuple) and len(box) == 4):
|
||||
continue
|
||||
try:
|
||||
x1, y1, x2, y2 = (float(v) for v in box)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
x1, x2 = x1 * sx, x2 * sx
|
||||
y1, y2 = y1 * sy, y2 * sy
|
||||
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
|
||||
label = str(det.get("label", "")).strip()
|
||||
if label:
|
||||
draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
|
||||
elif kind == "keypoint":
|
||||
point = payload.get("point")
|
||||
if isinstance(point, list | tuple) and len(point) == 2:
|
||||
try:
|
||||
x, y = float(point[0]) * sx, float(point[1]) * sy
|
||||
except (TypeError, ValueError):
|
||||
return img
|
||||
r = 6
|
||||
draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
|
||||
draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
|
||||
draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
|
||||
label = str(payload.get("label", "")).strip()
|
||||
if label:
|
||||
draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
|
||||
return img
|
||||
|
||||
|
||||
def _open_file(path: Path) -> None:
|
||||
"""Best-effort open ``path`` in the OS default viewer."""
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
subprocess.run(["open", str(path)], check=False)
|
||||
elif sys.platform.startswith("linux"):
|
||||
subprocess.run(["xdg-open", str(path)], check=False)
|
||||
elif os.name == "nt":
|
||||
os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606
|
||||
else: # pragma: no cover - exotic platform
|
||||
webbrowser.open(path.resolve().as_uri())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not auto-open %s: %s", path, exc)
|
||||
|
||||
|
||||
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
|
||||
"""Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it."""
|
||||
out = Path(out_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
path = out / f"vqa_{int(time.time() * 1000)}.png"
|
||||
image.save(path)
|
||||
_open_file(path)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def handle_vqa_query(
|
||||
*,
|
||||
policy: Any,
|
||||
observation_provider: Any,
|
||||
question: str,
|
||||
state: dict[str, Any],
|
||||
input_fn: Any = input,
|
||||
print_fn: Any = print,
|
||||
) -> None:
|
||||
"""Run one interactive VQA question end to end.
|
||||
|
||||
Called synchronously from the input layer while the runtime is in
|
||||
``/question`` mode (the action loop is gated off, so the policy is
|
||||
not in concurrent use). Progress is reported via both
|
||||
:func:`push_log` (REPL panel scrollback) and ``print_fn`` (direct
|
||||
stdout) — in autonomous question mode the panel redraw is suspended,
|
||||
so the direct print is what the operator actually sees.
|
||||
"""
|
||||
from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415
|
||||
|
||||
def report(line: str) -> None:
|
||||
"""Surface a line both to the panel scrollback and to stdout."""
|
||||
push_log(state, line)
|
||||
try:
|
||||
print_fn(line)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
if policy is None or not hasattr(policy, "select_message"):
|
||||
report(" [warn] vqa: policy has no select_message — skipping")
|
||||
return
|
||||
|
||||
observation: dict | None = None
|
||||
if observation_provider is not None:
|
||||
try:
|
||||
observation = observation_provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s", exc)
|
||||
|
||||
# Feed the FULL observation (every camera + state) to the VLM. The
|
||||
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
|
||||
# stripped before tokenization — the actual frames reach the model
|
||||
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
|
||||
# consumes *all* ``config.image_features`` regardless of which
|
||||
# camera the sub-recipe was tagged for. So the model always sees
|
||||
# every camera; the operator never has to name one to ask.
|
||||
answer = _generate_with_policy(
|
||||
policy,
|
||||
_msgs_for_vqa(question),
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
)
|
||||
if not answer:
|
||||
report(" [info] vqa gen returned empty")
|
||||
return
|
||||
report(f" vqa: {answer}")
|
||||
|
||||
parsed = parse_vqa_answer(answer)
|
||||
if not answer_has_overlay(parsed):
|
||||
if parsed is None:
|
||||
report(" [info] vqa answer is not JSON — no overlay")
|
||||
return
|
||||
|
||||
# The answer carries a bounding box / point. Its pixel coordinates
|
||||
# are camera-specific and the text answer doesn't say which camera,
|
||||
# so ask the operator *now* — only when there is actually something
|
||||
# to draw — which camera frame to render the overlay on.
|
||||
cameras = available_cameras(observation)
|
||||
if observation is None or not cameras:
|
||||
report(" [info] no camera image — cannot draw overlay")
|
||||
return
|
||||
chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
|
||||
if chosen is None:
|
||||
report(" [info] overlay skipped — no camera selected")
|
||||
return
|
||||
try:
|
||||
pil = observation_image_to_pil(observation[chosen])
|
||||
overlay = draw_vqa_overlay(pil, parsed)
|
||||
path = save_and_open_overlay(overlay)
|
||||
report(f" vqa overlay ({camera_short_name(chosen)}) saved: {path}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
report(f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,198 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 pre/post-processor factory.
|
||||
|
||||
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||
|
||||
rename observations
|
||||
add batch dim
|
||||
relative-action prep (inherited from π0.5)
|
||||
NormalizerProcessorStep
|
||||
RenderMessagesStep — recipe → messages, target_message_indices,
|
||||
message_streams (PR 1 of the steerable
|
||||
stack)
|
||||
PI052TextTokenizerStep — messages → input_ids + label mask +
|
||||
predict_actions
|
||||
DeviceProcessorStep
|
||||
|
||||
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
|
||||
so unannotated datasets keep working.
|
||||
|
||||
Post-processor is unchanged from π0.5.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
# RenderMessagesStep is intentionally not re-exported from
|
||||
# ``lerobot.processor`` because it pulls in optional language-stack deps;
|
||||
# import it directly.
|
||||
from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
|
||||
def make_pi052_pre_post_processors(
|
||||
config: PI052Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build PI0.5-v2's pre/post-processor pipelines.
|
||||
|
||||
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
|
||||
"""
|
||||
if not config.recipe_path:
|
||||
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||
|
||||
recipe = _load_recipe(config.recipe_path)
|
||||
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
exclude_joints=getattr(config, "relative_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
relative_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
RenderMessagesStep(recipe=recipe),
|
||||
PI052TextTokenizerStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
|
||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||
),
|
||||
]
|
||||
|
||||
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
|
||||
# Only inserted when explicitly enabled — keeps the post-training-
|
||||
# style recipe (flow + text) as the default. When on, the step
|
||||
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
|
||||
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
|
||||
if getattr(config, "enable_fast_action_loss", False):
|
||||
# Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
|
||||
# tokenizer on this dataset's action distribution rather than
|
||||
# using the universal codebook off the shelf. We do this once
|
||||
# and cache to disk, keyed on (dataset, base, n_samples).
|
||||
action_tokenizer_path = config.action_tokenizer_name
|
||||
if (
|
||||
getattr(config, "auto_fit_fast_tokenizer", False)
|
||||
and dataset_repo_id is not None
|
||||
):
|
||||
from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415
|
||||
|
||||
cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser()
|
||||
try:
|
||||
action_tokenizer_path = fit_fast_tokenizer(
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
cache_dir=cache_dir,
|
||||
base_tokenizer_name=config.action_tokenizer_name,
|
||||
n_samples=config.fast_tokenizer_fit_samples,
|
||||
chunk_size=config.chunk_size,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"FAST tokenizer fit failed (%s) — falling back to "
|
||||
"the universal base tokenizer %r. Train will still "
|
||||
"work but compression will be suboptimal.",
|
||||
exc, config.action_tokenizer_name,
|
||||
)
|
||||
|
||||
input_steps.append(
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=action_tokenizer_path,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
|
||||
)
|
||||
)
|
||||
|
||||
input_steps.append(DeviceProcessorStep(device=config.device))
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AbsoluteActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
relative_step=relative_step,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/``.
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||
|
||||
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||
candidate = configs_dir / path_str
|
||||
if candidate.exists():
|
||||
p = candidate
|
||||
return TrainingRecipe.from_yaml(p)
|
||||
@@ -1,641 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 text-tokenisation step.
|
||||
|
||||
PaliGemma is *not* chat-pretrained, so we can't lean on
|
||||
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
|
||||
messages as plain text with simple ``User: ... Assistant: ...`` role
|
||||
delimiters — matching the prompt format π0.5 uses in the paper
|
||||
(``Task: ... State: ... Action: ...``).
|
||||
|
||||
Outputs:
|
||||
|
||||
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` — the
|
||||
concatenated prompt tokenised by the PaliGemma tokenizer (the same
|
||||
one ``processor_pi05`` already uses).
|
||||
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||
positions belonging to messages whose index is in
|
||||
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
|
||||
those positions via the PaliGemma ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discretize_state_str(state_row: Any) -> str:
|
||||
"""Discretize a single normalized state vector into 256 bins, space-joined.
|
||||
|
||||
Mirrors pi05's ``Pi05PrepareStateTokenizerProcessorStep`` (same bins /
|
||||
convention) so pi052's low-level action prompt carries proprioception in
|
||||
the exact format pi05 was trained on. Expects state already normalized by
|
||||
the upstream ``NormalizerProcessorStep``.
|
||||
"""
|
||||
arr = state_row.detach().cpu().numpy() if hasattr(state_row, "detach") else np.asarray(state_row)
|
||||
disc = np.digitize(arr, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
return " ".join(str(int(x)) for x in disc.reshape(-1).tolist())
|
||||
|
||||
|
||||
def _state_row_at(state_all: Any, pos: int) -> Any:
|
||||
"""Select the per-sample state row from a (possibly batched) state tensor."""
|
||||
if state_all is None:
|
||||
return None
|
||||
if hasattr(state_all, "ndim") and state_all.ndim >= 2:
|
||||
return state_all[pos]
|
||||
return state_all
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b["text"]
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
|
||||
]
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize assistant ``say`` tool calls into a ``<say>...</say>`` marker.
|
||||
|
||||
PaliGemma's flat text prompt has no notion of structured tool calls,
|
||||
and ``_format_messages`` only reads ``role`` / ``content`` — so
|
||||
without this a ``say`` tool call is dropped entirely and never
|
||||
supervised. Rewriting it into the content text as a ``<say>...</say>``
|
||||
marker lets the LM head learn to emit it; the runtime parses it back
|
||||
via ``_split_plan_and_say``. Messages without ``say`` tool calls are
|
||||
returned unchanged (the structured calls, if any, are still dropped).
|
||||
"""
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
return message
|
||||
say_texts: list[str] = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
fn = call.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
continue
|
||||
args = fn.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
import json # noqa: PLC0415
|
||||
|
||||
args = json.loads(args)
|
||||
except (ValueError, TypeError):
|
||||
args = {}
|
||||
text = args.get("text", "") if isinstance(args, dict) else ""
|
||||
if text:
|
||||
say_texts.append(str(text))
|
||||
new = dict(message)
|
||||
new.pop("tool_calls", None)
|
||||
if not say_texts:
|
||||
return new
|
||||
base = _content_to_text(new.get("content")).strip()
|
||||
marker = "".join(f"<say>{t}</say>" for t in say_texts)
|
||||
new["content"] = f"{base}\n{marker}" if base else marker
|
||||
return new
|
||||
|
||||
|
||||
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalise a message's content to a plain string.
|
||||
|
||||
The recipe renderer can emit ``content`` as a string OR as a list
|
||||
of HF-style multimodal blocks (``{type: text, text: ...}``,
|
||||
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
|
||||
only consume strings, so we flatten: drop image blocks (cameras
|
||||
flow through ``observation.images.*`` separately) and join text
|
||||
block texts.
|
||||
"""
|
||||
new = dict(message)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
content = new.get("content")
|
||||
if content is None:
|
||||
new["content"] = ""
|
||||
elif isinstance(content, str):
|
||||
pass
|
||||
elif isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
t = block.get("text", "")
|
||||
if isinstance(t, str):
|
||||
parts.append(t)
|
||||
new["content"] = "\n".join(parts)
|
||||
else:
|
||||
new["content"] = str(content)
|
||||
return new
|
||||
|
||||
|
||||
def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
|
||||
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
if value is None:
|
||||
return [None] * batch_size
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.numel() == 1:
|
||||
return [int(value.item())] * batch_size
|
||||
values = value.reshape(-1).tolist()
|
||||
return [int(v) for v in values[:batch_size]]
|
||||
if isinstance(value, (list, tuple)):
|
||||
if len(value) == 1:
|
||||
return _sample_indices(value[0], batch_size)
|
||||
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
|
||||
return [int(value)] * batch_size
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
|
||||
#
|
||||
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
|
||||
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
|
||||
# answers are stored as JSON in Qwen2.5-VL's grounding convention:
|
||||
# **0–1000 normalized coordinates**, NOT pixels. (Verified empirically
|
||||
# on the published datasets: x and y both span 0..1000 with ~30% of
|
||||
# values exceeding the camera's pixel dimensions — they're not pixels.)
|
||||
# Converting to ``<loc>`` is therefore camera-resolution-independent:
|
||||
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
|
||||
# not in the dataset — so the dataset keeps the raw JSON and stays
|
||||
# backbone-agnostic.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The 0–1000 scale Qwen2.5-VL emits for grounding coordinates.
|
||||
_VQA_COORD_SCALE = 1000.0
|
||||
|
||||
|
||||
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
|
||||
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
|
||||
|
||||
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
|
||||
(detection / pointing) tokens, but the *stock* tokenizer does NOT
|
||||
match them when encoding raw text — it BPE-splits ``<loc0162>`` into
|
||||
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
|
||||
the LM head on a ``<loc>`` target then supervises those 7 generic
|
||||
BPE pieces instead of one detection-vocab id, the LM head learns to
|
||||
emit the *character sequence*, and those pieces' logits dominate
|
||||
other turns (the ``<loc>``-salad on subtasks). Registering the loc
|
||||
tokens once makes them tokenize as their single ids (256000+idx),
|
||||
leveraging PaliGemma's detection prior properly. Idempotent.
|
||||
"""
|
||||
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
return tokenizer
|
||||
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
|
||||
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
|
||||
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
|
||||
return f"<loc{max(0, min(1023, idx)):04d}>"
|
||||
|
||||
|
||||
def _vqa_answer_to_loc(answer: dict[str, Any]) -> str | None:
|
||||
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
|
||||
|
||||
Input coordinates are in Qwen2.5-VL's 0–1000 normalized space (see
|
||||
module-level note). y is emitted before x for each coordinate pair
|
||||
(PaliGemma convention), with the integer indices in [0, 1023].
|
||||
|
||||
**Format: label first, locs after.** PaliGemma's pretraining puts
|
||||
locs first (``<loc><loc> label``), but for our small-dataset VQA
|
||||
blend that turns the LM head into a loc-emission attractor at every
|
||||
``Assistant:`` position — VQA targets share their first supervised
|
||||
token with ~25% of all text samples, and the head collapses to
|
||||
emitting ``<loc>`` regardless of the prompt. Putting the label
|
||||
first (``label <locY><locX>``) means every text sample (subtask,
|
||||
memory, VQA, …) starts the supervised target with a real word,
|
||||
breaking the attractor. The model still learns the loc vocabulary
|
||||
for the *spatial* portion of the answer; it just can't fire it as
|
||||
the first generation step from a clean prompt.
|
||||
|
||||
Returns ``None`` for non-spatial answers (count / attribute /
|
||||
spatial-relation) — those keep their JSON form.
|
||||
"""
|
||||
point = answer.get("point")
|
||||
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
|
||||
try:
|
||||
x, y = float(point[0]), float(point[1])
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
label = str(answer.get("label", "")).strip()
|
||||
if not label:
|
||||
return None
|
||||
return f"{label} {_loc_token(y)}{_loc_token(x)}"
|
||||
|
||||
detections = answer.get("detections")
|
||||
if isinstance(detections, list) and detections:
|
||||
parts: list[str] = []
|
||||
for det in detections:
|
||||
if not isinstance(det, dict):
|
||||
continue
|
||||
box = det.get("bbox")
|
||||
if not (isinstance(box, list | tuple) and len(box) == 4):
|
||||
continue
|
||||
try:
|
||||
x1, y1, x2, y2 = (float(v) for v in box)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
label = str(det.get("label", "")).strip()
|
||||
if not label:
|
||||
continue
|
||||
toks = (
|
||||
f"{_loc_token(y1)}{_loc_token(x1)}"
|
||||
f"{_loc_token(y2)}{_loc_token(x2)}"
|
||||
)
|
||||
parts.append(f"{label} {toks}")
|
||||
return " ; ".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
|
||||
def _messages_vqa_to_loc(
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.
|
||||
|
||||
Each target turn whose content parses as a spatial VQA answer is
|
||||
converted. Non-spatial answers and subtask / memory targets (plain
|
||||
text → not JSON) are left untouched. Camera-independent: VQA coords
|
||||
are 0–1000 normalized, so no observation lookup is needed.
|
||||
"""
|
||||
if not target_indices:
|
||||
return messages
|
||||
out = list(messages)
|
||||
for idx in target_indices:
|
||||
if not (0 <= idx < len(out)):
|
||||
continue
|
||||
content = out[idx].get("content")
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
continue
|
||||
try:
|
||||
answer = json.loads(content)
|
||||
except (ValueError, TypeError):
|
||||
continue # subtask / memory targets are plain text — skip
|
||||
if not isinstance(answer, dict):
|
||||
continue
|
||||
loc_text = _vqa_answer_to_loc(answer)
|
||||
if loc_text is not None:
|
||||
out[idx] = {**out[idx], "content": loc_text}
|
||||
return out
|
||||
|
||||
|
||||
def _format_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int] | None = None,
|
||||
eos_token: str | None = None,
|
||||
) -> tuple[str, list[tuple[int, int]]]:
|
||||
"""Concatenate messages into the π0.5-style flat prompt.
|
||||
|
||||
When both ``target_indices`` and ``eos_token`` are given, the EOS
|
||||
string is appended to each supervised target turn's content and the
|
||||
returned span covers it — so the label builder marks the EOS token
|
||||
as a supervised label. That teaches the LM head where the answer
|
||||
*ends*: without an EOS in the target span the model is never given a
|
||||
stop signal and rambles to ``max_length`` at inference. Inference
|
||||
callers omit both args (no EOS baked into the prompt — the model
|
||||
generates it and ``select_message`` stops on it).
|
||||
|
||||
Returns:
|
||||
prompt: the full text the tokenizer will consume.
|
||||
msg_spans: list of ``(char_start, char_end)`` covering each
|
||||
message's supervised payload (content, plus the
|
||||
appended EOS for target turns) within ``prompt``.
|
||||
"""
|
||||
targets = set(target_indices or [])
|
||||
parts: list[str] = []
|
||||
spans: list[tuple[int, int]] = []
|
||||
cursor = 0
|
||||
for i, m in enumerate(messages):
|
||||
role = m.get("role", "user")
|
||||
content = m.get("content", "") or ""
|
||||
# Role tag + newline. The model has to learn to emit the same
|
||||
# role tokens at generation time, which is fine for greedy
|
||||
# decoding because the chat template is implicit in the
|
||||
# supervised target span.
|
||||
header = f"{role.capitalize()}: "
|
||||
# A supervised target turn ends with EOS so the model learns to
|
||||
# terminate; the span below covers content + EOS. Non-target
|
||||
# turns (and inference) carry no EOS.
|
||||
body = content + eos_token if (eos_token and i in targets) else content
|
||||
# span covers the content (+ EOS) portion only — never the role
|
||||
# tag — so labels are computed over the supervised payload.
|
||||
full = header + body + "\n"
|
||||
start = cursor + len(header)
|
||||
end = start + len(body)
|
||||
parts.append(full)
|
||||
spans.append((start, end))
|
||||
cursor += len(full)
|
||||
return "".join(parts), spans
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
|
||||
class PI052TextTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
No chat template; concatenates messages as
|
||||
``User: ... \\nAssistant: ...`` text.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
max_length: int = 200
|
||||
padding: str = "max_length"
|
||||
padding_side: str = "right"
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
interjection_dropout_prob: float = 0.0
|
||||
dropout_seed: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._tokenizer: Any = None
|
||||
|
||||
def _ensure_tokenizer(self) -> Any:
|
||||
if self._tokenizer is not None:
|
||||
return self._tokenizer
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
self._tokenizer = register_paligemma_loc_tokens(
|
||||
AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
)
|
||||
return self._tokenizer
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pipeline step
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
transition = transition.copy()
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
messages = complementary.get("messages") or []
|
||||
|
||||
if not messages:
|
||||
# No recipe was rendered — caller will fall back to the
|
||||
# plain Pi0.5 prompt path. We pass the transition through
|
||||
# unmodified.
|
||||
return transition
|
||||
|
||||
tokenizer = self._ensure_tokenizer()
|
||||
# Normalized proprioceptive state (set by NormalizerProcessorStep, which
|
||||
# runs before this step). Injected into low-level action prompts so the
|
||||
# action expert sees proprioception, matching pi05's discretized State:.
|
||||
state_all = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE)
|
||||
# VQA coords are 0–1000 normalized (Qwen2.5-VL convention) — the
|
||||
# <loc> conversion is camera-resolution-independent and needs no
|
||||
# observation lookup here.
|
||||
if _is_batched_messages(messages):
|
||||
indices_iter = _sample_indices(complementary.get("index"), len(messages))
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
msg,
|
||||
list(streams),
|
||||
list(tgt_indices),
|
||||
complementary,
|
||||
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||
state_row=_state_row_at(state_all, pos),
|
||||
)
|
||||
for pos, (msg, streams, tgt_indices, s_idx) in enumerate(
|
||||
zip(
|
||||
messages,
|
||||
complementary.get("message_streams") or [[] for _ in messages],
|
||||
complementary.get("target_message_indices") or [[] for _ in messages],
|
||||
indices_iter,
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
]
|
||||
else:
|
||||
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
messages,
|
||||
list(complementary.get("message_streams") or []),
|
||||
list(complementary.get("target_message_indices") or []),
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
state_row=_state_row_at(state_all, 0),
|
||||
)
|
||||
]
|
||||
|
||||
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
|
||||
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = {
|
||||
**complementary,
|
||||
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
|
||||
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
|
||||
}
|
||||
return transition
|
||||
|
||||
def _encode_messages(
|
||||
self,
|
||||
tokenizer: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
message_streams: list[str | None],
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
state_row: Any = None,
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
|
||||
# Optional: drop non-target messages per the dropout config.
|
||||
# Keeps the supervised-target indices stable by re-mapping
|
||||
# after removal.
|
||||
if (
|
||||
self.plan_dropout_prob
|
||||
or self.memory_dropout_prob
|
||||
or self.subtask_dropout_prob
|
||||
or self.interjection_dropout_prob
|
||||
):
|
||||
messages, target_indices = self._apply_prompt_dropout(
|
||||
messages,
|
||||
target_indices,
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
|
||||
# Rewrite bbox / keypoint VQA target answers from JSON to
|
||||
# PaliGemma <loc> text. Coords are 0–1000 normalized so this is
|
||||
# camera-independent.
|
||||
messages = _messages_vqa_to_loc(messages, target_indices)
|
||||
|
||||
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
|
||||
# stripping, so the spoken reply is actually tokenized and
|
||||
# supervised (PaliGemma's flat prompt has no structured calls).
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
||||
# Low-level (action-conditioning) samples get the discretized state
|
||||
# appended to their user message, mirroring pi05's
|
||||
# "..., State: {256-bin};" so the action expert sees proprioception.
|
||||
# Higher-level text streams (subtask/memory generation) stay state-free.
|
||||
if state_row is not None and any(s == "low_level" for s in message_streams):
|
||||
state_str = discretize_state_str(state_row)
|
||||
for m in reversed(messages):
|
||||
if m.get("role") == "user":
|
||||
base = _content_to_text(m.get("content", ""))
|
||||
m["content"] = f"{base}, State: {state_str};"
|
||||
break
|
||||
# Append EOS to supervised target turns so the LM head learns to
|
||||
# stop (the span covers it → it becomes a supervised label).
|
||||
prompt, spans = _format_messages(
|
||||
messages, target_indices, getattr(tokenizer, "eos_token", None)
|
||||
)
|
||||
|
||||
encoded = tokenizer(
|
||||
prompt,
|
||||
max_length=self.max_length,
|
||||
padding=self.padding,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_offsets_mapping=True,
|
||||
padding_side=self.padding_side,
|
||||
)
|
||||
|
||||
input_ids = encoded["input_ids"][0]
|
||||
attention_mask = encoded["attention_mask"][0].bool()
|
||||
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
|
||||
|
||||
# Build label mask: -100 everywhere except over supervised
|
||||
# target message char ranges.
|
||||
labels = torch.full_like(input_ids, fill_value=-100)
|
||||
for idx in target_indices:
|
||||
if idx >= len(spans):
|
||||
continue
|
||||
char_start, char_end = spans[idx]
|
||||
for token_pos in range(input_ids.shape[0]):
|
||||
if not attention_mask[token_pos]:
|
||||
continue
|
||||
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
|
||||
if tok_end <= char_start or tok_start >= char_end:
|
||||
continue
|
||||
labels[token_pos] = input_ids[token_pos]
|
||||
|
||||
# Scan ALL message streams (not just targets): the
|
||||
# ``low_level_execution`` recipe drops ``target: true`` on
|
||||
# the assistant to avoid trivial copy-from-user text-CE; the
|
||||
# flow loss still needs to fire, gated by ``stream: low_level``.
|
||||
predict_actions = torch.tensor(
|
||||
bool(any(s == "low_level" for s in message_streams)),
|
||||
dtype=torch.bool,
|
||||
)
|
||||
return input_ids, attention_mask, labels, predict_actions, prompt
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-component prompt dropout (Pi0.7 §V.E)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _apply_prompt_dropout(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||
"""Drop messages classified as plan/memory/subtask context.
|
||||
|
||||
Targets are *never* dropped (they're the supervised payload).
|
||||
Re-maps target_indices to the new positions after drops.
|
||||
"""
|
||||
import random # noqa: PLC0415
|
||||
|
||||
seed = self.dropout_seed
|
||||
if seed is None:
|
||||
# Canonical row-index key set by ``BatchProcessor`` /
|
||||
# ``render_messages_processor``. Falling back to other
|
||||
# keys silently gave every sample seed=0 → identical
|
||||
# dropout pattern across the whole epoch.
|
||||
seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
|
||||
try:
|
||||
if hasattr(seed_src, "item"):
|
||||
seed_src = seed_src.item()
|
||||
seed = int(seed_src)
|
||||
except (TypeError, ValueError):
|
||||
seed = 0
|
||||
rng = random.Random(seed)
|
||||
|
||||
keep_indices: list[int] = []
|
||||
for idx, msg in enumerate(messages):
|
||||
if idx in target_indices:
|
||||
keep_indices.append(idx)
|
||||
continue
|
||||
kind = _classify_for_dropout(msg)
|
||||
prob = {
|
||||
"plan": self.plan_dropout_prob,
|
||||
"memory": self.memory_dropout_prob,
|
||||
"subtask": self.subtask_dropout_prob,
|
||||
"interjection": self.interjection_dropout_prob,
|
||||
}.get(kind, 0.0)
|
||||
if prob > 0.0 and rng.random() < prob:
|
||||
continue
|
||||
keep_indices.append(idx)
|
||||
|
||||
# Build remap and apply
|
||||
new_messages = [messages[i] for i in keep_indices]
|
||||
old_to_new = {old: new for new, old in enumerate(keep_indices)}
|
||||
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
|
||||
return new_messages, new_targets
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
content = " ".join(text_parts)
|
||||
elif content is None:
|
||||
return None
|
||||
elif not isinstance(content, str):
|
||||
return None
|
||||
s = content.strip()
|
||||
if s.startswith("Plan:") or s.startswith("Previous plan"):
|
||||
return "plan"
|
||||
if s.startswith("Memory:") or s.startswith("Previous memory"):
|
||||
return "memory"
|
||||
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
|
||||
return "subtask"
|
||||
return None
|
||||
@@ -14,28 +14,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Default PaliGemma SigLIP input resolution. Mirrors
|
||||
# ``pi05.configuration_pi05.DEFAULT_IMAGE_SIZE``; duplicated as a plain constant
|
||||
# to avoid importing the pi05 package here (which would create an import cycle:
|
||||
# pi_gemma -> pi05.__init__ -> modeling_pi05 -> pi_gemma).
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import (
|
||||
GemmaAttention,
|
||||
GemmaConfig,
|
||||
@@ -59,8 +49,6 @@ else:
|
||||
GradientCheckpointingLayer = None
|
||||
BaseModelOutputWithPast = None
|
||||
create_causal_mask = None
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
|
||||
|
||||
def _gated_residual(
|
||||
@@ -287,8 +275,6 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
if causal_mask is not None and torch.is_floating_point(causal_mask):
|
||||
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
@@ -381,374 +367,3 @@ __all__ = [
|
||||
"PaliGemmaModelWithPiGemma",
|
||||
"PaliGemmaForConditionalGenerationWithPiGemma",
|
||||
]
|
||||
|
||||
|
||||
# PI0.5 / PI052 dual-expert backbone: generic PaliGemma + Gemma action-expert
|
||||
# transformer machinery used by the pi052 policy. GemmaVariantConfig is openpi's
|
||||
# width/depth variant config (renamed from GemmaConfig to avoid clashing with
|
||||
# transformers' GemmaConfig).
|
||||
|
||||
def sdpa_attention_forward(
|
||||
module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
|
||||
``torch.nn.functional.scaled_dot_product_attention``.
|
||||
|
||||
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
|
||||
bias masks (the FA backend only accepts causal/sliding-window). On
|
||||
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
|
||||
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
|
||||
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
|
||||
"""
|
||||
n_rep = module.num_key_value_groups
|
||||
if n_rep > 1:
|
||||
key = key.repeat_interleave(n_rep, dim=1)
|
||||
value = value.repeat_interleave(n_rep, dim=1)
|
||||
if attention_mask is not None and attention_mask.dtype != query.dtype:
|
||||
attention_mask = attention_mask.to(dtype=query.dtype)
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=dropout if module.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
return attn_output.transpose(1, 2).contiguous(), None
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
# Concatenate and process attention
|
||||
query_states = torch.cat(query_states, dim=2)
|
||||
key_states = torch.cat(key_states, dim=2)
|
||||
value_states = torch.cat(value_states, dim=2)
|
||||
dummy_tensor = torch.zeros(
|
||||
query_states.shape[0],
|
||||
query_states.shape[2],
|
||||
query_states.shape[-1],
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
att_output, _ = sdpa_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
return outputs_embeds
|
||||
|
||||
|
||||
class GemmaVariantConfig: # see openpi `gemma.py: Config`
|
||||
"""Configuration for Gemma model variants."""
|
||||
|
||||
def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.mlp_dim = mlp_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
|
||||
def get_gemma_config(variant: str) -> GemmaVariantConfig: # see openpi `gemma.py: get_config`
|
||||
"""Returns config for specified gemma variant."""
|
||||
if variant == "gemma_300m":
|
||||
return GemmaVariantConfig(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
elif variant == "gemma_2b":
|
||||
return GemmaVariantConfig(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(
|
||||
nn.Module
|
||||
): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi
|
||||
"""PaliGemma model with action expert for PI05."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vlm_config,
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||
freeze_vision_encoder: bool = False,
|
||||
train_expert_only: bool = False,
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
super().__init__()
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
|
||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||
vlm_config_hf.image_token_index = 257152
|
||||
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
||||
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
||||
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
||||
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.image_size = image_size
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
hidden_size=action_expert_config.width,
|
||||
intermediate_size=action_expert_config.mlp_dim,
|
||||
num_attention_heads=action_expert_config.num_heads,
|
||||
num_hidden_layers=action_expert_config.depth,
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
self._set_requires_grad()
|
||||
|
||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||
if precision == "bfloat16":
|
||||
self.to(dtype=torch.bfloat16)
|
||||
elif precision == "float32":
|
||||
self.to(dtype=torch.float32)
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
|
||||
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"lm_head",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
]
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_keep_float32):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def _set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
for param in self.paligemma.model.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
for param in self.paligemma.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
if self.freeze_vision_encoder:
|
||||
self.paligemma.model.vision_tower.eval()
|
||||
if self.train_expert_only:
|
||||
self.paligemma.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
# OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the
|
||||
# Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base
|
||||
# was trained in this regime, so scaling image features here over-scales them ~45x and
|
||||
# breaks the pretrained vision-language alignment. Keep image features un-normalized.
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
adarms_cond: list[torch.Tensor] | None = None,
|
||||
):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||
)
|
||||
prefix_past_key_values = prefix_output.past_key_values
|
||||
prefix_output = prefix_output.last_hidden_state
|
||||
suffix_output = None
|
||||
elif inputs_embeds[0] is None:
|
||||
suffix_output = self.gemma_expert.model.forward(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||
and self.gemma_expert.model.gradient_checkpointing
|
||||
and self.training
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
)
|
||||
|
||||
# final norm
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
# Apply gradient checkpointing to final norm if enabled
|
||||
if use_gradient_checkpointing:
|
||||
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_final_norms,
|
||||
inputs_embeds,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
||||
|
||||
prefix_output = outputs_embeds[0]
|
||||
suffix_output = outputs_embeds[1]
|
||||
prefix_past_key_values = None
|
||||
|
||||
return [prefix_output, suffix_output], prefix_past_key_values
|
||||
|
||||
|
||||
@@ -175,6 +175,9 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
|
||||
@@ -50,17 +50,7 @@ class RenderMessagesStep(ProcessorStep):
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||
if rendered is None:
|
||||
return transition
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
if _is_batched_language(persistent) or _is_batched_language(events):
|
||||
return self._call_batch(transition, complementary_data, persistent, events)
|
||||
return transition
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
@@ -77,147 +67,18 @@ class RenderMessagesStep(ProcessorStep):
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||
if rendered is None:
|
||||
return None
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def _call_batch(
|
||||
self,
|
||||
transition: EnvTransition,
|
||||
complementary_data: dict[str, Any],
|
||||
persistent_batch: list,
|
||||
events_batch: list,
|
||||
) -> EnvTransition | None:
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
batch_size = max(len(persistent_batch), len(events_batch))
|
||||
messages: list[list[dict[str, Any]]] = []
|
||||
message_streams: list[list[str | None]] = []
|
||||
target_message_indices: list[list[int]] = []
|
||||
keep_indices: list[int] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
|
||||
events=events_batch[i] if i < len(events_batch) else [],
|
||||
t=_batch_value(timestamp, i),
|
||||
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
|
||||
task=_batch_value(complementary_data.get("task"), i),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
|
||||
if rendered is None:
|
||||
continue
|
||||
keep_indices.append(i)
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
target_message_indices.append(rendered["target_message_indices"])
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
new_transition = (
|
||||
_select_batch_indices(transition, keep_indices)
|
||||
if len(keep_indices) != batch_size
|
||||
else transition.copy()
|
||||
)
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data["messages"] = messages
|
||||
new_complementary_data["message_streams"] = message_streams
|
||||
new_complementary_data["target_message_indices"] = target_message_indices
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
|
||||
|
||||
def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
|
||||
def _is_batched_language(value: Any) -> bool:
|
||||
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
|
||||
|
||||
|
||||
def _batch_value(value: Any, index: int) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return value[index]
|
||||
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
|
||||
return _scalar(value[index])
|
||||
return _scalar(value)
|
||||
|
||||
|
||||
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
|
||||
selected = transition.copy()
|
||||
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
|
||||
data = selected.get(key)
|
||||
if isinstance(data, dict):
|
||||
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
|
||||
action = selected.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
selected[TransitionKey.ACTION] = _select_value(action, indices)
|
||||
return selected
|
||||
|
||||
|
||||
def _select_value(value: Any, indices: list[int]) -> Any:
|
||||
if isinstance(value, list) and len(value) >= len(indices):
|
||||
return [value[i] for i in indices]
|
||||
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
||||
return value.index_select(0, value.new_tensor(indices).long())
|
||||
return value
|
||||
|
||||
|
||||
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
|
||||
"""Keep action-only samples trainable when no recipe branch matches."""
|
||||
if hasattr(task, "item"):
|
||||
task = task.item()
|
||||
if isinstance(task, list):
|
||||
messages = []
|
||||
message_streams = []
|
||||
target_message_indices = []
|
||||
for t in task:
|
||||
rendered = _fallback_low_level_render(t)
|
||||
if rendered is None:
|
||||
return None
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
target_message_indices.append(rendered["target_message_indices"])
|
||||
return {
|
||||
"messages": messages,
|
||||
"message_streams": message_streams,
|
||||
"target_message_indices": target_message_indices,
|
||||
}
|
||||
if not isinstance(task, str) or not task:
|
||||
return None
|
||||
return {
|
||||
"messages": [{"role": "user", "content": task}],
|
||||
"message_streams": ["low_level"],
|
||||
"target_message_indices": [],
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ import torch
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_CODE_TOKEN_MASK,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -413,15 +412,14 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# During inference, no action is available, skip tokenization
|
||||
return new_transition
|
||||
|
||||
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
|
||||
tokens, mask, code_mask = self._tokenize_action(action)
|
||||
# Tokenize and get both tokens and mask
|
||||
tokens, mask = self._tokenize_action(action)
|
||||
|
||||
# Store mask in complementary data
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if complementary_data is None:
|
||||
complementary_data = {}
|
||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
|
||||
complementary_data[ACTION_TOKENS] = tokens
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return new_transition
|
||||
@@ -432,7 +430,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
@@ -461,7 +459,6 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# The fast tokenizer expects action data and returns token IDs
|
||||
tokens_list = []
|
||||
masks_list = []
|
||||
code_masks_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||
@@ -479,26 +476,19 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
|
||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
prompt_tokens = torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
)
|
||||
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
|
||||
|
||||
code_start = 1 + len(prompt_tokens)
|
||||
code_end = code_start + len(action_code_tokens)
|
||||
# add bos
|
||||
tokens = torch.cat(
|
||||
[
|
||||
torch.tensor([bos_id], device=action.device),
|
||||
prompt_tokens,
|
||||
action_code_tokens,
|
||||
end_tokens,
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
),
|
||||
self._act_tokens_to_paligemma_tokens(tokens),
|
||||
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
|
||||
]
|
||||
)
|
||||
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
|
||||
code_mask[code_start:code_end] = True
|
||||
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
@@ -507,49 +497,44 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
|
||||
)
|
||||
tokens = tokens[: self.max_action_tokens]
|
||||
code_mask = code_mask[: self.max_action_tokens]
|
||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||
else:
|
||||
pad_len = self.max_action_tokens - len(tokens)
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
|
||||
torch.zeros(
|
||||
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
|
||||
),
|
||||
]
|
||||
)
|
||||
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
|
||||
# Pad tokens with zeros
|
||||
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
|
||||
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
|
||||
|
||||
tokens_list.append(tokens)
|
||||
masks_list.append(mask)
|
||||
code_masks_list.append(code_mask)
|
||||
|
||||
# Stack into batched tensors
|
||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
tokens_batch = tokens_batch.squeeze(0)
|
||||
masks_batch = masks_batch.squeeze(0)
|
||||
code_masks_batch = code_masks_batch.squeeze(0)
|
||||
|
||||
# Move to the same device as the input
|
||||
if device is not None:
|
||||
tokens_batch = tokens_batch.to(device)
|
||||
masks_batch = masks_batch.to(device)
|
||||
code_masks_batch = code_masks_batch.to(device)
|
||||
|
||||
return tokens_batch, masks_batch, code_masks_batch
|
||||
return tokens_batch, masks_batch
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is not used since we override __call__.
|
||||
Required by ActionProcessorStep ABC.
|
||||
"""
|
||||
tokens, _, _ = self._tokenize_action(action)
|
||||
tokens, _ = self._tokenize_action(action)
|
||||
return tokens
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
@@ -21,8 +21,6 @@ from lerobot.utils.import_utils import make_device_from_device_class
|
||||
from .config import RobotConfig
|
||||
from .robot import Robot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
@@ -120,7 +118,7 @@ def ensure_safe_goal_position(
|
||||
}
|
||||
|
||||
if warnings_dict:
|
||||
logger.warning(
|
||||
logging.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f"{pformat(warnings_dict, indent=4)}"
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ from .configs import (
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
@@ -50,7 +49,6 @@ from .inference import (
|
||||
from .strategies import (
|
||||
BaseStrategy,
|
||||
DAggerStrategy,
|
||||
EpisodicStrategy,
|
||||
HighlightStrategy,
|
||||
RolloutStrategy,
|
||||
SentryStrategy,
|
||||
@@ -68,8 +66,6 @@ __all__ = [
|
||||
"HardwareContext",
|
||||
"HighlightStrategy",
|
||||
"HighlightStrategyConfig",
|
||||
"EpisodicStrategy",
|
||||
"EpisodicStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
|
||||
@@ -121,35 +121,6 @@ class DAggerPedalConfig:
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("episodic")
|
||||
@dataclass
|
||||
class EpisodicStrategyConfig(RolloutStrategyConfig):
|
||||
"""Episode-oriented recording that mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Records ``dataset.num_episodes`` episodes of maximum ``dataset.episode_time_s`` each.
|
||||
After each episode, runs ``dataset.reset_time_s`` seconds of reset time.
|
||||
|
||||
Keyboard controls:
|
||||
Right arrow — end current episode or reset phase early
|
||||
Left arrow — discard current episode and re-record
|
||||
Escape — stop recording session
|
||||
|
||||
In between episodes:
|
||||
- if there is no teleop leader, the robot is held at its initial joint positions captured at startup.
|
||||
- else, the robot is moved smoothly to the position of the teleop leader.
|
||||
"""
|
||||
|
||||
# This only applies if there are no teleop leaders specified.
|
||||
# When True (default), moves the robot back to the joint positions captured at startup.
|
||||
# Otherwise, leave the robot in its current position.
|
||||
reset_to_initial_position: bool = True
|
||||
|
||||
# Whether to turn on or off the leader -> follower smooth handover behavior.
|
||||
# When False, fallback to follower -> leader handover.
|
||||
# Note that leader -> follower handover is only supported when the leader has `send_feedback` capability.
|
||||
smooth_leader_to_follower_handover: bool = True
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
@@ -258,13 +229,7 @@ class RolloutConfig:
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy,
|
||||
(
|
||||
SentryStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
EpisodicStrategyConfig,
|
||||
),
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
@@ -28,7 +27,6 @@ __all__ = [
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"EpisodicStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
|
||||
@@ -56,14 +56,10 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
@@ -73,6 +69,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
@@ -174,6 +171,64 @@ class DAggerEvents:
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -701,31 +756,31 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
|
||||
if teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, prev_action)
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, prev_action, target)
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if teleop_supports_feedback(teleop):
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if teleop_supports_feedback(teleop):
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
@@ -735,7 +790,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if teleop_supports_feedback(teleop):
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -1,335 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Episodic rollout strategy: mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
- Policy drives the robot during each recording episode.
|
||||
- An optional teleoperator can drive the robot during reset phases so the
|
||||
operator can bring the environment back to its starting configuration.
|
||||
If no teleop is connected the robot stays in its current position.
|
||||
- Keyboard controls:
|
||||
|
||||
Right arrow — end the current episode or reset phase early
|
||||
Left arrow — discard the current episode and re-record it
|
||||
Escape — stop the recording session
|
||||
|
||||
Dataset naming follows the rollout convention: repo names must start with ``rollout_``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import (
|
||||
follower_smooth_move_to,
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
teleop_smooth_move_to,
|
||||
teleop_supports_feedback,
|
||||
)
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..configs import EpisodicStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodicStrategy(RolloutStrategy):
|
||||
"""Policy-driven multi-episode recording, mirrors the behavior of ``lerobot-record``.
|
||||
|
||||
Each recording episode runs the policy for maximum ``dataset.episode_time_s``
|
||||
seconds, recording every frame. A reset phase of ``dataset.reset_time_s``
|
||||
follows every episode (except the last) so the operator can manually
|
||||
reset the environment. During the reset phase, an optional teleoperator
|
||||
drives the robot; if none is present the robot returns to its initial joint positions captured at startup.
|
||||
|
||||
The policy state (hidden state, RTC queue, interpolator) is reset at
|
||||
the start of each recording episode.
|
||||
|
||||
Keyboard events:
|
||||
right arrow → end current episode or reset phase early
|
||||
left arrow → discard & re-record current episode
|
||||
ESC → stop the session
|
||||
"""
|
||||
|
||||
config: EpisodicStrategyConfig
|
||||
|
||||
def __init__(self, config: EpisodicStrategyConfig) -> None:
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._events: dict | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Start the inference engine and attach the keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
self._listener, self._events = init_keyboard_listener()
|
||||
logger.info("Episodic strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main multi-episode recording loop."""
|
||||
cfg = ctx.runtime.cfg
|
||||
dataset_cfg = cfg.dataset
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
fps = cfg.fps
|
||||
episode_time_s = dataset_cfg.episode_time_s
|
||||
reset_time_s = dataset_cfg.reset_time_s
|
||||
num_episodes = dataset_cfg.num_episodes
|
||||
single_task = dataset_cfg.single_task or cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
display_compressed = (
|
||||
True
|
||||
if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None)
|
||||
else cfg.display_compressed_images
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < num_episodes and not events["stop_recording"]:
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
# Reset policy state at episode start (discard leftover hidden state / queue)
|
||||
self._engine.reset()
|
||||
self._interpolator.reset()
|
||||
self._engine.resume()
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
self._policy_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
events=events,
|
||||
features=features,
|
||||
fps=fps,
|
||||
control_time_s=episode_time_s,
|
||||
dataset=dataset,
|
||||
single_task=single_task,
|
||||
)
|
||||
|
||||
# Reset phase, skip after the last episode (but run when re-recording)
|
||||
if not events["stop_recording"] and (
|
||||
recorded_episodes < num_episodes - 1 or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
|
||||
if teleop:
|
||||
# Smooth handover so the transition to teleop control is jerk-free.
|
||||
# For actuated teleops: drive the leader arm to the follower's current
|
||||
# position so the operator takes over without fighting the arm.
|
||||
# For non-actuated teleops: slide the follower to the teleop's current
|
||||
# pose instead, since the leader cannot be driven.
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
if (
|
||||
teleop_supports_feedback(teleop)
|
||||
and self.config.smooth_leader_to_follower_handover
|
||||
):
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
teleop_smooth_move_to(teleop, current_pos, duration_s=2)
|
||||
teleop.disable_torque()
|
||||
else:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
follower_smooth_move_to(robot, current_pos, target, duration_s=1)
|
||||
|
||||
elif self.config.reset_to_initial_position:
|
||||
# No teleop: return the robot to its startup position.
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
self._reset_loop(
|
||||
ctx=ctx,
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
events=events,
|
||||
fps=fps,
|
||||
control_time_s=reset_time_s,
|
||||
display_data=cfg.display_data,
|
||||
display_compressed=display_compressed,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# returns to its initial joint positions captured at startup
|
||||
if not teleop and self.config.reset_to_initial_position:
|
||||
self._return_to_initial_position(hw=ctx.hardware, duration_s=1)
|
||||
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
finally:
|
||||
# Save any frames buffered in the current episode so an unexpected
|
||||
# exception or KeyboardInterrupt does not silently drop recorded data.
|
||||
# suppress: save_episode raises if the buffer is empty (nothing to lose).
|
||||
logger.info("Episodic control loop ended — saving any in-progress episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def _policy_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
events: dict,
|
||||
features: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
dataset,
|
||||
single_task: str,
|
||||
) -> None:
|
||||
"""Policy-driven recording loop for a single episode."""
|
||||
interpolator = self._interpolator
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
|
||||
if self._handle_warmup(ctx.runtime.cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
dataset.add_frame({**obs_frame, **action_frame, "task": single_task})
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
if sleep_t < 0:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({fps} Hz). "
|
||||
"Dataset frames might be dropped and robot control might be unstable. "
|
||||
"Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long "
|
||||
"3) CPU starvation"
|
||||
)
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def _reset_loop(
|
||||
self,
|
||||
ctx: RolloutContext,
|
||||
robot,
|
||||
teleop,
|
||||
events: dict,
|
||||
fps: float,
|
||||
control_time_s: float,
|
||||
display_data: bool,
|
||||
display_compressed: bool,
|
||||
) -> None:
|
||||
"""Reset-phase loop: teleop drives the robot if available, no recording."""
|
||||
processors = ctx.processors
|
||||
control_interval = 1.0 / fps
|
||||
|
||||
timestamp = 0.0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if ctx.runtime.shutdown_event.is_set():
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
if teleop is not None:
|
||||
act = teleop.get_action()
|
||||
act_teleop = processors.teleop_action_processor((act, obs))
|
||||
robot_action = processors.robot_action_processor((act_teleop, obs))
|
||||
robot.send_action(robot_action)
|
||||
|
||||
if display_data:
|
||||
obs_processed = processors.robot_observation_processor(obs)
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=act_teleop,
|
||||
compress_images=display_compressed,
|
||||
)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_t = control_interval - dt
|
||||
precise_sleep(max(sleep_t, 0.0))
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Finalise dataset, stop listener, push to hub, and disconnect hardware."""
|
||||
cfg = ctx.runtime.cfg
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
|
||||
if not is_headless() and self._listener is not None:
|
||||
self._listener.stop()
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
|
||||
if (
|
||||
cfg.dataset is not None
|
||||
and cfg.dataset.push_to_hub
|
||||
and ctx.data.dataset is not None
|
||||
and safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=cfg.dataset.tags,
|
||||
private=cfg.dataset.private,
|
||||
)
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=cfg.return_to_initial_position,
|
||||
)
|
||||
log_say("Exiting", play_sounds)
|
||||
logger.info("Episodic strategy teardown complete")
|
||||
@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .episodic import EpisodicStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
@@ -43,8 +42,4 @@ def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
if config.type == "episodic":
|
||||
return EpisodicStrategy(config)
|
||||
raise ValueError(
|
||||
f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger, episodic"
|
||||
)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``lerobot-annotate`` — populate ``language_persistent`` and
|
||||
``language_events`` columns on a LeRobot dataset.
|
||||
|
||||
Annotations live directly in ``data/chunk-*/file-*.parquet``.
|
||||
|
||||
Example:
|
||||
|
||||
uv run lerobot-annotate \\
|
||||
--root=/path/to/dataset \\
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
For distributed runs, see ``examples/annotations/run_hf_job.py``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs import parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
|
||||
if cfg.root is not None:
|
||||
return Path(cfg.root)
|
||||
if cfg.repo_id is not None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset"))
|
||||
raise ValueError("Either --root or --repo_id must be provided.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Run the steerable annotation pipeline against a dataset."""
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
root = _resolve_root(cfg)
|
||||
logger.info("annotate: root=%s", root)
|
||||
|
||||
vlm = make_vlm_client(cfg.vlm)
|
||||
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend)
|
||||
# Surface the resolved cameras up front so a silent vqa-module no-op
|
||||
# is obvious in job output rather than discovered post-hoc by counting
|
||||
# parquet rows.
|
||||
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
|
||||
logger.info(
|
||||
"annotate: frame_provider default camera=%r, all cameras=%s",
|
||||
getattr(frame_provider, "camera_key", None),
|
||||
cam_keys,
|
||||
)
|
||||
if cfg.vqa.enabled and not cam_keys:
|
||||
logger.warning(
|
||||
"annotate: the vqa module is enabled but no cameras were "
|
||||
"resolved — it will produce zero VQA rows. Check "
|
||||
"meta/info.json for observation.images.* features, or pass "
|
||||
"--vlm.camera_key=<key> to seed the cameras list."
|
||||
)
|
||||
plan = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.plan, frame_provider=frame_provider)
|
||||
interjections = InterjectionsAndSpeechModule(
|
||||
vlm=vlm, config=cfg.interjections, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
vqa = GeneralVqaModule(vlm=vlm, config=cfg.vqa, seed=cfg.seed, frame_provider=frame_provider)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
plan=plan,
|
||||
interjections=interjections,
|
||||
vqa=vqa,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
summary = executor.run(root)
|
||||
logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
|
||||
for phase in summary.phases:
|
||||
logger.info(
|
||||
"annotate: phase=%s processed=%d skipped=%d",
|
||||
phase.name,
|
||||
phase.episodes_processed,
|
||||
phase.episodes_skipped,
|
||||
)
|
||||
if summary.validation_report.warnings:
|
||||
for w in summary.validation_report.warnings:
|
||||
logger.warning(w)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
if cfg.repo_id is None and cfg.new_repo_id is None:
|
||||
raise ValueError(
|
||||
"--push_to_hub requires --repo_id or --new_repo_id (the dataset repo to push to)."
|
||||
)
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
|
||||
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Upload the annotated dataset directory to the Hub.
|
||||
|
||||
Pushes to ``cfg.new_repo_id`` when set, otherwise back to ``cfg.repo_id``.
|
||||
"""
|
||||
from huggingface_hub import HfApi # noqa: PLC0415
|
||||
|
||||
repo_id = cfg.new_repo_id or cfg.repo_id
|
||||
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
|
||||
api = HfApi()
|
||||
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
|
||||
api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
private=cfg.push_private,
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
|
||||
commit_info = api.upload_folder(
|
||||
folder_path=str(root),
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
|
||||
)
|
||||
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
|
||||
|
||||
# Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
|
||||
# resolves the dataset revision via ``get_safe_version`` which scans
|
||||
# for tags like ``v3.0``; without a tag it raises
|
||||
# ``RevisionNotFoundError``. Read the version straight from the
|
||||
# dataset's own ``meta/info.json`` so we tag whatever the writer
|
||||
# actually wrote (no accidental drift if the codebase floor moves).
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
version_tag = CODEBASE_VERSION
|
||||
if info_path.exists():
|
||||
try:
|
||||
from lerobot.utils.io_utils import load_json # noqa: PLC0415
|
||||
|
||||
info = load_json(info_path)
|
||||
ds_version = info.get("codebase_version")
|
||||
if isinstance(ds_version, str) and ds_version.startswith("v"):
|
||||
version_tag = ds_version
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}",
|
||||
flush=True,
|
||||
)
|
||||
revision = getattr(commit_info, "oid", None)
|
||||
tag_kwargs = {
|
||||
"repo_id": repo_id,
|
||||
"tag": version_tag,
|
||||
"repo_type": "dataset",
|
||||
"exist_ok": True,
|
||||
}
|
||||
if revision is not None:
|
||||
tag_kwargs["revision"] = revision
|
||||
|
||||
try:
|
||||
api.create_tag(**tag_kwargs)
|
||||
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] WARNING: could not create tag {version_tag!r} on {repo_id}: {exc}. "
|
||||
"Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
|
||||
"Run: from huggingface_hub import HfApi; "
|
||||
f"HfApi().create_tag({repo_id!r}, tag={version_tag!r}, repo_type='dataset', exist_ok=True)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
annotate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -95,67 +95,6 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _wrap_text_to_width(text: str, cv2, font, scale: int, thickness: int, max_width: int) -> list[str]:
|
||||
"""Greedy word-wrap using measured pixel width so text fits the frame."""
|
||||
words = text.split()
|
||||
lines: list[str] = []
|
||||
current = ""
|
||||
for word in words:
|
||||
candidate = f"{current} {word}".strip()
|
||||
(w, _), _ = cv2.getTextSize(candidate, font, scale, thickness)
|
||||
if w > max_width and current:
|
||||
lines.append(current)
|
||||
current = word
|
||||
else:
|
||||
current = candidate
|
||||
if current:
|
||||
lines.append(current)
|
||||
return lines or [""]
|
||||
|
||||
|
||||
def _annotate_eval_frames(
|
||||
frames: np.ndarray, task: str | None, subtask: str | None
|
||||
) -> np.ndarray:
|
||||
"""Overlay the high-level task and predicted subtask onto rendered frames.
|
||||
|
||||
``frames`` is ``(n_envs, H, W, C)`` uint8. Best-effort: if OpenCV isn't
|
||||
available the frames are returned unchanged so eval never fails over a
|
||||
visualization concern.
|
||||
"""
|
||||
if frames.ndim != 4 or frames.shape[-1] != 3:
|
||||
return frames
|
||||
try:
|
||||
import cv2 # noqa: PLC0415
|
||||
except ImportError:
|
||||
return frames
|
||||
|
||||
width = frames.shape[2]
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
scale = 0.5
|
||||
margin = 6
|
||||
max_width = width - 2 * margin
|
||||
|
||||
lines: list[str] = []
|
||||
if task:
|
||||
lines += _wrap_text_to_width(f"Task: {task}", cv2, font, scale, 1, max_width)
|
||||
if subtask:
|
||||
lines += _wrap_text_to_width(f"Subtask: {subtask}", cv2, font, scale, 1, max_width)
|
||||
if not lines:
|
||||
return frames
|
||||
|
||||
out = frames.copy()
|
||||
for i in range(out.shape[0]):
|
||||
img = np.ascontiguousarray(out[i])
|
||||
y = 18
|
||||
for line in lines:
|
||||
# Black outline then white fill so text stays legible on any scene.
|
||||
cv2.putText(img, line, (margin, y), font, scale, (0, 0, 0), 3, cv2.LINE_AA)
|
||||
cv2.putText(img, line, (margin, y), font, scale, (255, 255, 255), 1, cv2.LINE_AA)
|
||||
y += 20
|
||||
out[i] = img
|
||||
return out
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -386,42 +325,11 @@ def eval_policy(
|
||||
return
|
||||
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
|
||||
if isinstance(env, gym.vector.SyncVectorEnv):
|
||||
frames = np.stack([env.envs[i].render() for i in range(n_to_render_now)]) # noqa: B023
|
||||
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
|
||||
elif hasattr(env, "call"):
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
# Covers AsyncVectorEnv and _LazyAsyncVectorEnv (which wraps one).
|
||||
frames = np.stack(env.call("render")[:n_to_render_now])
|
||||
else:
|
||||
return
|
||||
|
||||
# Overlay the high-level task and (for hierarchical policies like
|
||||
# pi052) the predicted low-level subtask onto each frame. Both are
|
||||
# best-effort: missing values just skip that line.
|
||||
try:
|
||||
tasks = list(env.call("task_description"))
|
||||
except (AttributeError, NotImplementedError):
|
||||
try:
|
||||
tasks = list(env.call("task"))
|
||||
except (AttributeError, NotImplementedError):
|
||||
tasks = None
|
||||
# Per-env subtasks when available (batched hierarchical policies);
|
||||
# fall back to the scalar last_subtask for single-env / other policies.
|
||||
subtasks = getattr(policy, "last_subtasks", None)
|
||||
subtask_scalar = getattr(policy, "last_subtask", None)
|
||||
annotated = []
|
||||
for i in range(frames.shape[0]):
|
||||
if subtasks is not None and i < len(subtasks):
|
||||
subtask_i = subtasks[i]
|
||||
else:
|
||||
subtask_i = subtask_scalar
|
||||
annotated.append(
|
||||
_annotate_eval_frames(
|
||||
frames[i : i + 1],
|
||||
tasks[i] if tasks is not None and i < len(tasks) else None,
|
||||
subtask_i,
|
||||
)[0]
|
||||
)
|
||||
ep_frames.append(np.stack(annotated))
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,6 @@ Strategies
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
--strategy.type=episodic Episode-oriented recording with reset phases
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
@@ -112,18 +111,6 @@ Usage examples
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Episodic mode — episode-oriented recording with reset phases
|
||||
lerobot-rollout \\
|
||||
--strategy.type=episodic \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/rollout_episodic_data \\
|
||||
--dataset.num_episodes=20 \\
|
||||
--dataset.single_task="Grab the cube"
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
|
||||
@@ -20,7 +20,6 @@ Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wand
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
@@ -44,7 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -162,170 +161,6 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def _print_debug_text_predictions(
|
||||
policy: Any, batch: dict[str, Any], step: int, n_samples: int = 5
|
||||
) -> None:
|
||||
"""Forward the current batch and print head-argmax vs label per supervised position.
|
||||
|
||||
Opt-in via ``LEROBOT_DEBUG_PREDS_EVERY=<step_interval>``. Only the
|
||||
policy types that expose ``debug_text_predictions`` participate
|
||||
(currently PI052); others are silently skipped. Pretty-prints up to
|
||||
``n_samples`` samples from the current batch, showing the prompt,
|
||||
every supervised position's (label, prediction, ✓/✗), and a
|
||||
per-sample token-accuracy summary — the cheapest "is text training
|
||||
actually learning anything" signal.
|
||||
"""
|
||||
# Accelerator/DDP wraps the policy in a ``module`` attribute and
|
||||
# doesn't proxy custom methods through, so a naive
|
||||
# ``hasattr(policy, "debug_text_predictions")`` returns False on the
|
||||
# wrapper — and the helper would silently no-op. Walk through any
|
||||
# ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers)
|
||||
# to reach the raw policy that actually defines the method.
|
||||
inner = policy
|
||||
while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"):
|
||||
inner = inner.module
|
||||
if not hasattr(inner, "debug_text_predictions"):
|
||||
logging.warning(
|
||||
"LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no "
|
||||
"debug_text_predictions method — skipping dump.",
|
||||
type(inner).__name__,
|
||||
)
|
||||
return
|
||||
try:
|
||||
debug = inner.debug_text_predictions(batch, max_samples=n_samples)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("debug_text_predictions failed: %s", exc, exc_info=True)
|
||||
return
|
||||
if not debug:
|
||||
logging.warning(
|
||||
"debug_text_predictions returned no supervised samples — "
|
||||
"current batch has no text labels."
|
||||
)
|
||||
return
|
||||
policy = inner # used below for select_message-style decoding parity
|
||||
|
||||
# Build a tokenizer for decoding — match training side exactly.
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
tok_name = (
|
||||
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
)
|
||||
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("debug preds: tokenizer load failed: %s", exc)
|
||||
return
|
||||
|
||||
ids = debug["input_ids"]
|
||||
labels = debug["labels"]
|
||||
preds = debug["predictions"]
|
||||
attn = debug["attention_mask"]
|
||||
|
||||
n = ids.shape[0]
|
||||
print(
|
||||
f"\n========== STEP {step} DEBUG PREDICTIONS ({n} samples) ==========",
|
||||
flush=True,
|
||||
)
|
||||
for s in range(n):
|
||||
a = attn[s].tolist()
|
||||
real = sum(a)
|
||||
sid = ids[s].tolist()
|
||||
sl = labels[s].tolist()
|
||||
sp = preds[s].tolist()
|
||||
prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
|
||||
print(f"\n --- sample {s + 1}/{n} ---", flush=True)
|
||||
print(f" prompt: {prompt!r}", flush=True)
|
||||
|
||||
# Ground-truth target (the contiguous supervised label span).
|
||||
sup_ids = [int(sid[i]) for i in range(real) if sl[i] != -100]
|
||||
if sup_ids:
|
||||
print(
|
||||
f" target (ground truth) : {tokenizer.decode(sup_ids, skip_special_tokens=False)!r}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Training-side teacher-forced argmax on the same prompt+target.
|
||||
n_sup = n_ok = 0
|
||||
teacher_chars: list[int] = []
|
||||
for i in range(1, real):
|
||||
label = sl[i]
|
||||
if label == -100:
|
||||
continue
|
||||
n_sup += 1
|
||||
pred = int(sp[i - 1])
|
||||
teacher_chars.append(pred)
|
||||
if label == pred:
|
||||
n_ok += 1
|
||||
teacher_text = (
|
||||
tokenizer.decode(teacher_chars, skip_special_tokens=False) if teacher_chars else ""
|
||||
)
|
||||
acc = n_ok / max(n_sup, 1)
|
||||
print(
|
||||
f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}",
|
||||
flush=True,
|
||||
)
|
||||
print("=" * 60 + "\n", flush=True)
|
||||
|
||||
|
||||
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
|
||||
"""Build per-frame sampling weights that oversample VQA-annotated frames.
|
||||
|
||||
Scans the dataset's ``language_events`` column for frames carrying a
|
||||
``vqa``-style annotation and returns a weight tensor (length == total
|
||||
dataset frames) such that, under multinomial sampling, VQA frames make up
|
||||
roughly ``target_fraction`` of the training stream.
|
||||
|
||||
Returns ``None`` (⇒ fall back to uniform episode-aware sampling) when VQA
|
||||
frames cannot be detected or there are none.
|
||||
"""
|
||||
if not 0.0 < target_fraction < 1.0:
|
||||
logging.warning(
|
||||
"vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.",
|
||||
target_fraction,
|
||||
)
|
||||
return None
|
||||
hf = getattr(dataset, "hf_dataset", None)
|
||||
if hf is None or "language_events" not in getattr(hf, "column_names", []):
|
||||
logging.warning(
|
||||
"Dataset has no `language_events` column — VQA oversampling disabled."
|
||||
)
|
||||
return None
|
||||
|
||||
events_col = hf["language_events"]
|
||||
n_frames = len(events_col)
|
||||
is_vqa = torch.zeros(n_frames, dtype=torch.bool)
|
||||
for i, rows in enumerate(events_col):
|
||||
if rows and any((row or {}).get("style") == "vqa" for row in rows):
|
||||
is_vqa[i] = True
|
||||
|
||||
n_vqa = int(is_vqa.sum())
|
||||
if n_vqa == 0:
|
||||
logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.")
|
||||
return None
|
||||
n_other = n_frames - n_vqa
|
||||
|
||||
# Solve target = (n_vqa·w) / (n_vqa·w + n_other) for the VQA weight w.
|
||||
# Clamp to ≥ 1 so VQA frames are never *down*-weighted below uniform.
|
||||
weight = (target_fraction * n_other) / ((1.0 - target_fraction) * max(n_vqa, 1))
|
||||
weight = max(weight, 1.0)
|
||||
weights = torch.ones(n_frames, dtype=torch.double)
|
||||
weights[is_vqa] = weight
|
||||
logging.info(
|
||||
"VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); "
|
||||
"weighting them x%.2f to target ~%.0f%% of the training stream.",
|
||||
n_vqa,
|
||||
n_frames,
|
||||
100.0 * n_vqa / n_frames,
|
||||
weight,
|
||||
100.0 * target_fraction,
|
||||
)
|
||||
return weights
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
"""
|
||||
@@ -355,26 +190,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
# We set find_unused_parameters=True to handle models with conditional computation
|
||||
if accelerator is None:
|
||||
from datetime import timedelta
|
||||
|
||||
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Bump the c10d store-get / barrier timeout so the rank-0-only
|
||||
# ``make_dataset`` block below doesn't trigger a barrier crash on
|
||||
# large datasets. Default is 10 min (``store->get`` 600 s); a
|
||||
# 32 k-episode v3 dataset (e.g. ``robocasa_pretrain_human300_v4``)
|
||||
# spends >13 min on rank 0 building the episode/frame index
|
||||
# while ranks 1-N idle at ``wait_for_everyone()`` and crash with
|
||||
# ``DistBackendError: ... wait timeout after 600000ms``. 2 h is
|
||||
# plenty of headroom; fast paths are unaffected.
|
||||
ipg_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2))
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
force_cpu = cfg.trainable_config.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs, ipg_kwargs],
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
cpu=force_cpu,
|
||||
)
|
||||
|
||||
@@ -468,27 +292,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
# pi052: even when loading pretrained weights, build the processors
|
||||
# from the current pi052 config so the recipe text-label and FAST
|
||||
# action-label steps are generated and not silently swapped for the
|
||||
# checkpoint's older processor stack.
|
||||
if cfg.policy.type == "pi052" and processor_pretrained_path is not None and not cfg.resume:
|
||||
logging.warning(
|
||||
"pi052 is loading pretrained weights from %s, but building processors from the current "
|
||||
"pi052 config so recipe text labels and FAST action labels are generated.",
|
||||
processor_pretrained_path,
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_relative_actions=true with pretrained processors can skip relative transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
processor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
@@ -497,14 +300,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if cfg.is_reward_model_training:
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
# For pi052 (and any future policy that auto-fits part of its
|
||||
# preprocessing per-dataset), pass the dataset repo id so the
|
||||
# processor factory can locate/refresh dataset-specific artifacts
|
||||
# (e.g. fitted FAST tokenizers per Pertsch et al. 2025 [64],
|
||||
# π0.5 §III.C).
|
||||
if cfg.policy.type == "pi052":
|
||||
processor_kwargs["dataset_repo_id"] = cfg.dataset.repo_id
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": device.type},
|
||||
@@ -591,29 +386,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
from_indices = dataset.meta.episodes["dataset_from_index"]
|
||||
to_indices = dataset.meta.episodes["dataset_to_index"]
|
||||
# When `vqa_target_fraction` is set, oversample VQA-annotated
|
||||
# frames via a weighted sampler; otherwise plain episode-aware.
|
||||
vqa_weights = None
|
||||
if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming:
|
||||
vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction)
|
||||
if vqa_weights is not None:
|
||||
sampler = WeightedEpisodeAwareSampler(
|
||||
from_indices,
|
||||
to_indices,
|
||||
vqa_weights,
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
)
|
||||
else:
|
||||
sampler = EpisodeAwareSampler(
|
||||
from_indices,
|
||||
to_indices,
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
@@ -644,54 +423,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
policy.train()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# EMA setup
|
||||
# ------------------------------------------------------------------
|
||||
# Shadow copy of the trainable params for late-training averaging
|
||||
# (Chi et al. 2023 Diffusion Policy §V.D; openpi JAX trainer ships
|
||||
# this with decay=0.999 for pi05_libero; openpi PyTorch port and
|
||||
# LeRobot main both skip it). Off by default; opt in with
|
||||
# ``--ema.enable=true``. Implemented via ema-pytorch
|
||||
# (https://github.com/lucidrains/ema-pytorch) — the standard PyTorch
|
||||
# EMA library, also used by lucidrains' diffusion repos.
|
||||
ema = None
|
||||
if cfg.ema.enable and is_main_process:
|
||||
from ema_pytorch import EMA # noqa: PLC0415
|
||||
|
||||
ema = EMA(
|
||||
accelerator.unwrap_model(policy),
|
||||
beta=cfg.ema.decay,
|
||||
update_after_step=cfg.ema.warmup_steps,
|
||||
update_every=1, # update on every ema.update() call
|
||||
# Don't register the live model as an ema submodule — accelerator
|
||||
# already owns its lifecycle, and double-registration would
|
||||
# double-count its params in ``ema.state_dict()``.
|
||||
include_online_model=False,
|
||||
)
|
||||
ema.to(accelerator.device)
|
||||
logging.info(
|
||||
"EMA enabled (ema-pytorch): beta=%g, update_after_step=%d, "
|
||||
"use_for_eval=%s, use_for_wandb_examples=%s",
|
||||
cfg.ema.decay,
|
||||
cfg.ema.warmup_steps,
|
||||
cfg.ema.use_for_eval,
|
||||
cfg.ema.use_for_wandb_examples,
|
||||
)
|
||||
|
||||
# Resume the EMA shadow if a previous run wrote one.
|
||||
if cfg.checkpoint_path is not None:
|
||||
ema_path = cfg.checkpoint_path / "training_state" / "ema_state.pt"
|
||||
if ema_path.exists():
|
||||
logging.info("Resuming EMA shadow from %s", ema_path)
|
||||
try:
|
||||
ema.load_state_dict(torch.load(ema_path, map_location=accelerator.device))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"Failed to load EMA shadow (%s) — restarting EMA from "
|
||||
"current live weights",
|
||||
exc,
|
||||
)
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
@@ -744,14 +475,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sample_weighter=sample_weighter,
|
||||
)
|
||||
|
||||
# EMA update: pull one step of the live weights into the shadow.
|
||||
# Runs only on the main process (the shadow lives there); other
|
||||
# ranks rely on the live model staying in sync via accelerator.
|
||||
# ``ema-pytorch`` holds an internal reference to the online model
|
||||
# (set at construction), so ``ema.update()`` takes no args.
|
||||
if ema is not None:
|
||||
ema.update()
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
@@ -762,27 +485,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
# Optional periodic head-prediction dump for the LM head:
|
||||
# ``LEROBOT_DEBUG_PREDS_EVERY=1000`` prints 5 samples + per-token
|
||||
# (label, argmax, ✓/✗) every 1000 steps. Cheap diagnostic to see
|
||||
# whether the text head is actually learning what we expect, vs
|
||||
# collapsing to a fixed token. Refilling the recipe-sample dump
|
||||
# budget at the same cadence also redumps the raw input shapes.
|
||||
_debug_preds_every = int(os.environ.get("LEROBOT_DEBUG_PREDS_EVERY", "0"))
|
||||
if (
|
||||
_debug_preds_every > 0
|
||||
and step % _debug_preds_every == 0
|
||||
and is_main_process
|
||||
):
|
||||
try:
|
||||
from lerobot.policies.pi052 import text_processor_pi052 as _tp # noqa: PLC0415
|
||||
|
||||
_tp._DUMPED_SO_FAR = 0
|
||||
_tp._DUMP_BUDGET = max(_tp._DUMP_BUDGET, 5)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
_print_debug_text_predictions(policy, batch, step, n_samples=5)
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
@@ -793,49 +495,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
# EMA observability: ``ema.step`` is the count of
|
||||
# ``ema.update()`` calls (= optimizer steps once EMA is
|
||||
# enabled); ``ema.initted`` flips to True once we've
|
||||
# crossed ``update_after_step``.
|
||||
if ema is not None:
|
||||
wandb_log_dict["ema/step"] = int(ema.step.item())
|
||||
wandb_log_dict["ema/initted"] = float(ema.initted.item())
|
||||
wandb_log_dict["ema/beta"] = float(cfg.ema.decay)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
# Periodic training-example dump to wandb (camera images + text
|
||||
# fields + action endpoints). Opt-in via ``--wandb.log_examples_freq``;
|
||||
# independent of ``--log_freq`` so you can keep scalar logs frequent
|
||||
# and the heavier visual dump rare (e.g. every 5000 steps).
|
||||
if (
|
||||
wandb_logger is not None
|
||||
and cfg.wandb.log_examples_freq > 0
|
||||
and step % cfg.wandb.log_examples_freq == 0
|
||||
and is_main_process
|
||||
):
|
||||
try:
|
||||
# Optionally use the EMA shadow model directly for the
|
||||
# predicted-action columns (matches what eval / deployment
|
||||
# would see). ``ema-pytorch`` exposes the shadow as a
|
||||
# full ``nn.Module`` at ``ema.ema_model``, so we just
|
||||
# pass that instead of swap-and-restore.
|
||||
target_policy = (
|
||||
ema.ema_model
|
||||
if (ema is not None and cfg.ema.use_for_wandb_examples)
|
||||
else accelerator.unwrap_model(policy)
|
||||
)
|
||||
wandb_logger.log_training_examples(
|
||||
batch=batch,
|
||||
step=step,
|
||||
camera_keys=list(dataset.meta.camera_keys),
|
||||
n_samples=cfg.wandb.log_examples_n,
|
||||
policy=target_policy,
|
||||
predict_actions=cfg.wandb.log_examples_predict_actions,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("wandb log_training_examples failed: %s", exc)
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
if is_main_process:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
@@ -851,18 +513,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
# Save the EMA shadow alongside the training state so a
|
||||
# resumed run picks up exactly where the live EMA left off.
|
||||
# ``ema-pytorch.state_dict()`` returns the full shadow
|
||||
# nn.Module's state dict + step/initted buffers; saved as
|
||||
# .pt (the rest of training_state mixes formats already).
|
||||
if ema is not None:
|
||||
try:
|
||||
ema_path = checkpoint_dir / "training_state" / "ema_state.pt"
|
||||
ema_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(ema.state_dict(), ema_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("Failed to save EMA shadow: %s", exc)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -872,20 +522,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
# Use the EMA shadow model for eval when enabled —
|
||||
# standard practice for diffusion-style policies (~1–3%
|
||||
# lift on closed-loop success). ``ema.ema_model`` is a
|
||||
# full nn.Module clone, so we just pass it through; no
|
||||
# swap/restore on the live policy needed.
|
||||
eval_target_policy = (
|
||||
ema.ema_model
|
||||
if (ema is not None and cfg.ema.use_for_eval)
|
||||
else accelerator.unwrap_model(policy)
|
||||
)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
eval_info = eval_policy_all(
|
||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy=eval_target_policy,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""LeRobot tool implementations.
|
||||
|
||||
Storage of the tool catalog (``meta/info.json["tools"]``) and the
|
||||
``SAY_TOOL_SCHEMA`` constant live in PR 1
|
||||
(``lerobot.datasets.language``). This package holds the *runnable*
|
||||
implementations one file per tool, plus the registry that maps tool
|
||||
names to classes.
|
||||
|
||||
See ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
|
||||
from .base import Tool
|
||||
from .registry import TOOL_REGISTRY, get_tools
|
||||
from .say import SayTool
|
||||
|
||||
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
|
||||
@@ -1,58 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tool protocol — the contract every runnable tool implementation honors.
|
||||
|
||||
Tools are the executable side of the OpenAI-style function-calling
|
||||
abstraction the v3.1 language schema (PR 1) carries on assistant
|
||||
messages: the schema describes *what can be called*, the tool
|
||||
implementation describes *how to call it*.
|
||||
|
||||
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
|
||||
``say.py`` for ``SayTool``) and are registered in
|
||||
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
|
||||
heavy dependencies (torch models, audio backends, network clients,
|
||||
hardware drivers) only load when the dataset actually declares the tool.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
"""Minimum surface every tool must expose."""
|
||||
|
||||
#: Name matching ``schema["function"]["name"]``. The runtime dispatcher
|
||||
#: routes incoming ``tool_calls`` to the implementation by this key.
|
||||
name: str
|
||||
|
||||
#: OpenAI-style function-call schema. Same dict the dataset stores in
|
||||
#: ``meta/info.json["tools"]`` and the chat template renders into the
|
||||
#: prompt.
|
||||
schema: dict[str, Any]
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Execute the tool with the model-provided arguments.
|
||||
|
||||
``arguments`` is the parsed dict from
|
||||
``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
|
||||
when the model emits a JSON-string by the chat-template
|
||||
convention). Implementations validate the dict against their own
|
||||
schema; the runtime only routes by name.
|
||||
|
||||
Return value is implementation-defined — typically a tensor
|
||||
(TTS audio), a Path (saved file), a dict (structured result), or
|
||||
``None`` (side-effect-only call).
|
||||
"""
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user