mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 22:59:50 +00:00
Compare commits
102 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 965d42825f | |||
| 1238a0cd47 | |||
| 53c7641885 | |||
| 088c8371df | |||
| 3a52a18b0e | |||
| dad2cf1178 | |||
| bce5387e04 | |||
| c8ce413d73 | |||
| 82dffde7fa | |||
| eaf0218bc8 | |||
| a0e52d52fe | |||
| 85576acc29 | |||
| e7e5fca5de | |||
| beb22afd81 | |||
| e99c55af4b | |||
| 408e0ca763 | |||
| d55b581ca1 | |||
| 24d2ffe3c6 | |||
| 789f29aa56 | |||
| a356b12c41 | |||
| e8327b8e62 | |||
| c450298147 | |||
| 5c30b14929 | |||
| ce24063efd | |||
| 82934719db | |||
| 401a217597 | |||
| 40094b0464 | |||
| 8fa8323c91 | |||
| fdbfc015a2 | |||
| 73740ecf4b | |||
| 1b81e49214 | |||
| d813c75b76 | |||
| 3434d2ef22 | |||
| b71e10da6b | |||
| 0f6e3230df | |||
| 2f2e42c4aa | |||
| 5ee0104739 | |||
| e064cfcb04 | |||
| b3d9494831 | |||
| 1217fdb6f0 | |||
| d0388e1142 | |||
| 524aa59faa | |||
| 27f7829b09 | |||
| 7f8bf108e8 | |||
| 855ff027f8 | |||
| 3b797bb118 | |||
| aea04721ae | |||
| ab5479129a | |||
| e6d4ac6f02 | |||
| 5722d365c5 | |||
| 3d7e60cee4 | |||
| 7b767d4d60 | |||
| f1e3ab7794 | |||
| 585341ba9f | |||
| 23ff346027 | |||
| 3c5cbe7af4 | |||
| f2cbd97635 | |||
| c06c8d594a | |||
| cd495a3a9d | |||
| c99ac45cd1 | |||
| 13aaafeae0 | |||
| 2129648bf4 | |||
| f5cd3f6e4e | |||
| ecf5766301 | |||
| 11597d4f71 | |||
| 8b9c598cf4 | |||
| b325475b38 | |||
| ef137ff86a | |||
| c5df821a96 | |||
| 7ec3d7999c | |||
| 712d63abbd | |||
| 6653999983 | |||
| 4bdbedc9a0 | |||
| e240305e8e | |||
| ccd189b264 | |||
| ef1242bbd4 | |||
| ebf4a04d41 | |||
| 4419b4ef1b | |||
| ff06ca82d2 | |||
| fcb01e73eb | |||
| 268f8d1f53 | |||
| 663fff0ae2 | |||
| 9d6af804bf | |||
| f763f85213 | |||
| e3e9374e2c | |||
| c1a0c601e2 | |||
| d656da8ccc | |||
| 1ca38d9748 | |||
| 5a6aa64570 | |||
| b5f65e5332 | |||
| cd6b43ea7a | |||
| 2236bbe7a3 | |||
| cb0a944941 | |||
| 8a3d64033f | |||
| 03ee50e08f | |||
| ca87ccd941 | |||
| 77352c495c | |||
| 0b06790da0 | |||
| b43dc39ba4 | |||
| 2b71221194 | |||
| 8833d735a1 | |||
| 05a5223885 |
@@ -382,6 +382,7 @@ jobs:
|
||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||
--env.type=robotwin \
|
||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -482,6 +483,7 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -693,6 +695,7 @@ jobs:
|
||||
--env.task=\"\$ROBOMME_TASKS\" \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0] \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -800,6 +803,7 @@ jobs:
|
||||
--env.type=libero_plus \
|
||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -900,6 +904,8 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--env.episode_length=50 \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
package_name: lerobot
|
||||
secrets:
|
||||
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
|
||||
@@ -24,14 +24,14 @@ on:
|
||||
|
||||
env:
|
||||
CLOSE_ISSUE_MESSAGE: >
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
This issue was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
This PR was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
@@ -59,10 +59,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-issue-stale: 365
|
||||
days-before-issue-close: 30
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
days-before-pr-close: 30
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
|
||||
--env.episode_length=5 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.batch_size=1
|
||||
|
||||
# E2E annotation pipeline smoke test against a tiny in-memory fixture
|
||||
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
|
||||
# backend, so it does not require a real model checkpoint or GPU.
|
||||
annotation-e2e:
|
||||
uv run python -m tests.annotations.run_e2e_smoke
|
||||
|
||||
@@ -35,7 +35,7 @@ USER root
|
||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
||||
cuda-nvcc-12-6 cuda-cudart-dev-12-6 \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
|
||||
@@ -18,9 +18,8 @@
|
||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||
|
||||
# Configure the base image for CI with GPU access
|
||||
# TODO(Steven): Bump these versions
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG OS_VERSION=22.04
|
||||
ARG CUDA_VERSION=12.6.3
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
@@ -36,16 +35,13 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install Python, system dependencies, and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common build-essential git curl \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
build-essential git curl \
|
||||
libglib2.0-0 libgl1 libegl1 ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
cmake pkg-config ninja-build \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
|
||||
@@ -31,8 +31,12 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: language_and_recipes
|
||||
title: Language Columns and Recipes
|
||||
- local: tools
|
||||
title: Tools
|
||||
- local: annotation_pipeline
|
||||
title: Annotation Pipeline
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
@@ -47,6 +51,8 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
@@ -61,6 +67,8 @@
|
||||
title: SARM
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
title: Policy Deployment (lerobot-rollout)
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
# Annotation Pipeline
|
||||
|
||||
`lerobot-annotate` populates the two language columns introduced by the
|
||||
[Language Columns and Recipes](./language_and_recipes) page —
|
||||
`language_persistent` and `language_events` — directly into
|
||||
`data/chunk-*/file-*.parquet`.
|
||||
|
||||
## What the pipeline produces
|
||||
|
||||
Three modules write into a per-episode staging tree, then a single writer
|
||||
rewrites the data shards in place:
|
||||
|
||||
| Style / atom | Column | Module |
|
||||
| ------------------------------------------- | --------------------- | -------- |
|
||||
| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | Module 1 |
|
||||
| `plan` (initial + refresh on interjection) | `language_persistent` | Module 1 |
|
||||
| `memory` (MEM-style compression) | `language_persistent` | Module 1 |
|
||||
| `task_aug` (rephrasings of canonical task) | `language_persistent` | Module 1 |
|
||||
| `interjection` | `language_events` | Module 2 |
|
||||
| speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 |
|
||||
| `vqa` (user / assistant pair) | `language_events` | Module 3 |
|
||||
|
||||
The writer does **not** add a `tools` column to the parquet — the tool
|
||||
catalog lives at `meta/info.json["tools"]` instead (see
|
||||
[Tools](./tools)). After every annotation run the pipeline ensures the
|
||||
canonical `say` schema is present in that list, preserving any tools the
|
||||
user pre-declared.
|
||||
|
||||
If you want to declare additional tools for a dataset before annotation
|
||||
runs, edit `meta/info.json["tools"]` directly — the pipeline preserves
|
||||
anything already there. Implementations of those tools live under
|
||||
`src/lerobot/tools/`; one file per tool, registered via
|
||||
`TOOL_REGISTRY`. See the [Tools](./tools) doc for the authoring guide.
|
||||
|
||||
## Running locally
|
||||
|
||||
Install the extra and invoke the console script. Episode-level
|
||||
concurrency comes from `--executor.episode_parallelism` (default 16);
|
||||
that is the only knob the in-process executor exposes.
|
||||
|
||||
```bash
|
||||
uv sync --extra annotations
|
||||
uv run lerobot-annotate \
|
||||
--root=/path/to/dataset \
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
The pipeline attaches actual camera footage to every Module 1/2/3 prompt
|
||||
by default, decoded from the dataset's first `observation.images.*`
|
||||
stream. Override with `--vlm.camera_key=observation.images.<name>` to
|
||||
pin a specific viewpoint. Datasets with no video tracks fall back to
|
||||
text-only prompts automatically.
|
||||
|
||||
**Module 1 sees the whole episode as one video block.** Subtask
|
||||
decomposition gets a `{"type":"video", "video":[<frames>]}` block
|
||||
covering the entire demonstration; Qwen-VL pools temporally on its own
|
||||
and decides where to cut. There is no keyframe stride or count knob —
|
||||
`--module_1.max_video_frames` (default 128) only caps the frames packed
|
||||
into the video block as a model-capacity bound. Module 2 attaches a
|
||||
short window of frames around the interjection timestamp; Module 3
|
||||
attaches the exact emission frame to each VQA pair.
|
||||
|
||||
## Running on Hugging Face Jobs
|
||||
|
||||
Distributed annotation is delegated to
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs). The repo
|
||||
ships a launcher script you copy and edit for your dataset:
|
||||
|
||||
```bash
|
||||
HF_TOKEN=hf_... uv run python examples/annotation/run_hf_job.py
|
||||
```
|
||||
|
||||
[`examples/annotation/run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotation/run_hf_job.py)
|
||||
spawns one `h200x2` job that:
|
||||
|
||||
1. installs the branch under test plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) for the chosen model,
|
||||
3. runs Modules 1 / 2 / 3 across the dataset via `lerobot-annotate`,
|
||||
4. uploads the annotated dataset to `--push_to_hub`.
|
||||
|
||||
To target a different dataset, model, or hub repo, edit the `CMD` block
|
||||
inside the script — every flag in there maps directly onto a CLI flag of
|
||||
`lerobot-annotate` (see `lerobot-annotate --help` for the full list).
|
||||
|
||||
## Style-to-recipe consumer mapping
|
||||
|
||||
The pipeline's outputs are designed to be consumed by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)) — typically:
|
||||
|
||||
- low-level / high-level / memory-update branches consume
|
||||
`subtask`/`plan`/`memory` from `language_persistent`.
|
||||
- An interjection-response branch consumes `interjection` events plus
|
||||
the paired speech atom (merged into one assistant target turn via
|
||||
`tool_calls_from`) and the same-timestamp `plan` refresh.
|
||||
- A VQA branch consumes the `(vqa, user)` and `(vqa, assistant)` pairs
|
||||
from `language_events`.
|
||||
|
||||
## Why the design splits state from events
|
||||
|
||||
Two things drive the scope:
|
||||
|
||||
1. **Persistent state vs exact-event split.** Persistent rows
|
||||
(`subtask`, `plan`, `memory`) broadcast per episode and answer "what
|
||||
state is in force at this frame?". Event rows (`interjection`, `vqa`,
|
||||
speech) only appear on the exact frame whose timestamp matches the
|
||||
emission. The pipeline writes timestamps taken straight from the
|
||||
source parquet — no floating-point recomputation.
|
||||
2. **One Qwen-VL pass.** All three modules share a single VLM client
|
||||
(vLLM if available, transformers fallback) so the cost is one model
|
||||
load per dataset, not three.
|
||||
|
||||
## Module independence and staged reruns
|
||||
|
||||
Each module writes its raw output to
|
||||
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
|
||||
prompt iteration cheap — re-running one module overwrites only its own
|
||||
JSONL file before the writer composes the final parquet. Modules can be
|
||||
disabled via `--module_1.enabled=false` (and similarly for 2 and 3) to
|
||||
test them in isolation.
|
||||
|
||||
## Validation/report checks before final write
|
||||
|
||||
Before the writer runs, `StagingValidator` checks:
|
||||
|
||||
- exact frame-timestamp alignment for every event row;
|
||||
- no orphan speech / interjection pairs;
|
||||
- `plan` is refreshed at every interjection timestamp;
|
||||
- `memory` rows fall on subtask boundaries (warning, not error);
|
||||
- VQA assistant `content` parses as JSON in one of the
|
||||
bbox / keypoint / count / attribute / spatial shapes;
|
||||
- every row routes to the column dictated by `column_for_style(style)`.
|
||||
|
||||
Errors abort the writer (`--skip_validation=true` overrides for debugging).
|
||||
|
||||
## Paper inspirations per module
|
||||
|
||||
- **Module 1 — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
|
||||
atom granularity ("pick up one piece of lettuce", "place bowl to box");
|
||||
Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
|
||||
what" detail.
|
||||
- **Module 1 — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
|
||||
compression directive: keep only minimal relevant information; functional
|
||||
outcomes preserved, specific attributes dropped.
|
||||
- **Module 2 — interjections.** Hi Robot scenario taxonomy: negative task,
|
||||
situated correction, specific constraint, preference. Speech is a
|
||||
tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
|
||||
arguments:{text:...}}}]`).
|
||||
- **Module 3 — VQA.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
|
||||
grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
|
||||
keypoints) and Steerable VLA Policies ([Zhao 2025](https://arxiv.org/abs/2509.07626))
|
||||
multi-abstraction grounding. Pi0.7 also grounds answers across
|
||||
multiple abstraction levels.
|
||||
|
||||
Future maintainers should adjust the prompt templates in
|
||||
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
|
||||
references rather than rewriting from scratch.
|
||||
|
||||
## Compute and list-size estimates
|
||||
|
||||
Per episode, the pipeline issues O(`max_steps`) Module 1 calls,
|
||||
O(`max_interjections_per_episode`) Module 2 calls, and
|
||||
O(`vqa_emission_hz × episode_seconds`) Module 3 calls. With defaults
|
||||
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
|
||||
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
|
||||
KB at most (parquet dictionary-encodes one entry per episode);
|
||||
`language_events` is empty on most frames and is bounded by the number of
|
||||
emissions, not `num_frames × num_emissions`.
|
||||
|
||||
## Reproducibility via seed and prompt hashes
|
||||
|
||||
`--seed` (default 1729) feeds the per-episode RNGs that select interjection
|
||||
timestamps and VQA question types. Combined with the deterministic prompt
|
||||
templates checked into `prompts/`, two runs at the same seed against the
|
||||
same dataset and the same model checkpoint produce byte-identical staging
|
||||
artifacts. Prompt edits are recorded by file hash; future tooling can pin
|
||||
expected `(seed, prompt_hash)` pairs into the dataset card.
|
||||
@@ -1,277 +0,0 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor import TokenizerProcessorStep
|
||||
|
||||
# Create a tokenizer processor step
|
||||
tokenizer_processor = TokenizerProcessorStep(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -0,0 +1,168 @@
|
||||
# EO-1
|
||||
|
||||
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||
alt="An overview of EO-1"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=eo1` configuration through LeRobot
|
||||
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||
|
||||
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EO-1 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1]"
|
||||
```
|
||||
|
||||
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1,libero]"
|
||||
```
|
||||
|
||||
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EO-1 expects a LeRobot dataset with:
|
||||
|
||||
- At least one visual observation, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction through the dataset `task` field
|
||||
|
||||
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=eo1
|
||||
```
|
||||
|
||||
By default, a new EO-1 policy initializes its backbone from:
|
||||
|
||||
```python
|
||||
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||
```
|
||||
|
||||
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-eo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=eo1 \
|
||||
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.attn_implementation=sdpa \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--output_dir=./outputs/eo1_training \
|
||||
--job_name=eo1_training \
|
||||
--steps=300000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||
|
||||
## Evaluation
|
||||
|
||||
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20
|
||||
```
|
||||
|
||||
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Processing
|
||||
|
||||
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||
|
||||
### State and Action Dimensions
|
||||
|
||||
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||
|
||||
### Attention Backend
|
||||
|
||||
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||
|
||||
## References
|
||||
|
||||
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{eo1,
|
||||
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||
journal={arXiv preprint},
|
||||
year={2025},
|
||||
url={https://arxiv.org/abs/2508.21112}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
|
||||
|
||||
### Teleoperator Requirements
|
||||
|
||||
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
|
||||
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
|
||||
|
||||
- Enable/disable torque programmatically
|
||||
- Move to target positions (to mirror the robot state when pausing)
|
||||
|
||||
**Compatible teleoperators in the current `examples/hil` scripts:**
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `openarm_mini` - OpenArm Mini
|
||||
- `so_leader` - SO100 / SO101 leader arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||
|
||||
---
|
||||
|
||||
## Script
|
||||
|
||||
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
|
||||
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
|
||||
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | -------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | ---------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
|
||||
|
||||
---
|
||||
|
||||
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
**Standard inference (ACT, Diffusion Policy):**
|
||||
|
||||
```bash
|
||||
python examples/hil/hil_data_collection.py \
|
||||
lerobot-rollout --strategy.type=dagger \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
@@ -108,11 +108,10 @@ python examples/hil/hil_data_collection.py \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
--dataset.fps=30 \
|
||||
--dataset.episode_time_s=1000 \
|
||||
--dataset.num_episodes=50 \
|
||||
--strategy.num_episodes=50 \
|
||||
--interpolation_multiplier=2
|
||||
```
|
||||
|
||||
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \
|
||||
For models with high inference latency, enable RTC for smooth execution:
|
||||
|
||||
```bash
|
||||
python examples/hil/hil_data_collection.py \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
lerobot-rollout --strategy.type=dagger \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=20 \
|
||||
--inference.rtc.max_guidance_weight=5.0 \
|
||||
--inference.rtc.prefix_attention_schedule=LINEAR \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
@@ -136,11 +135,10 @@ python examples/hil/hil_data_collection.py \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
--dataset.fps=30 \
|
||||
--dataset.episode_time_s=1000 \
|
||||
--dataset.num_episodes=50 \
|
||||
--strategy.num_episodes=50 \
|
||||
--interpolation_multiplier=3
|
||||
```
|
||||
|
||||
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
|
||||
|
||||
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
|
||||
|
||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
|
||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
|
||||
|
||||
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
|
||||
|
||||
|
||||
+26
-105
@@ -509,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
## Run inference and evaluate your policy
|
||||
|
||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
<hfoption id="Base mode (no recording)">
|
||||
```bash
|
||||
lerobot-record \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_awesome_leader_arm \
|
||||
--policy.path=${HF_USER}/my_policy
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Run the policy inference loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
dataset.push_to_hub()
|
||||
<hfoption id="Sentry mode (with recording)">
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--duration=600
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
The `--strategy.type` flag selects the execution mode:
|
||||
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
|
||||
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
|
||||
@@ -0,0 +1,261 @@
|
||||
# Policy Deployment (lerobot-rollout)
|
||||
|
||||
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
|
||||
|
||||
## Quick Start
|
||||
|
||||
No extra dependencies are needed beyond your robot and policy extras.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=lerobot/act_koch_real \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--task="pick up cube" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
This runs the policy for 30 seconds with no recording.
|
||||
|
||||
---
|
||||
|
||||
## Strategies
|
||||
|
||||
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
|
||||
|
||||
### Base (`--strategy.type=base`)
|
||||
|
||||
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the box" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| ---------------- | ------------------------------------------------------ |
|
||||
| `--duration` | Run time in seconds (0 = infinite) |
|
||||
| `--task` | Task description passed to the policy |
|
||||
| `--display_data` | Stream observations/actions to Rerun for visualization |
|
||||
|
||||
### Sentry (`--strategy.type=sentry`)
|
||||
|
||||
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
|
||||
|
||||
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=${HF_USER}/rollout_eval_data \
|
||||
--dataset.single_task="Put lego brick into the box" \
|
||||
--duration=3600
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| -------------------------------------- | ----------------------------------------------------------- |
|
||||
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
|
||||
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
|
||||
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
|
||||
|
||||
### Highlight (`--strategy.type=highlight`)
|
||||
|
||||
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=highlight \
|
||||
--strategy.ring_buffer_seconds=30 \
|
||||
--strategy.save_key=s \
|
||||
--strategy.push_key=h \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--dataset.repo_id=${HF_USER}/rollout_highlight_data \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ------------------ | -------------------------------------------------------- |
|
||||
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
|
||||
| `h` (configurable) | Push dataset to Hub |
|
||||
| `ESC` | Stop the session |
|
||||
|
||||
| Flag | Description |
|
||||
| -------------------------------------- | ---------------------------------------------- |
|
||||
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
|
||||
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
|
||||
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
|
||||
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
|
||||
|
||||
### DAgger (`--strategy.type=dagger`)
|
||||
|
||||
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
|
||||
|
||||
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
|
||||
|
||||
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=dagger \
|
||||
--strategy.num_episodes=20 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--teleop.type=openarm_mini \
|
||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
```
|
||||
|
||||
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=dagger \
|
||||
--strategy.record_autonomous=true \
|
||||
--strategy.num_episodes=50 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/rollout_dagger_data \
|
||||
--dataset.single_task="Grasp the block"
|
||||
```
|
||||
|
||||
**Keyboard controls** (default input device):
|
||||
|
||||
| Key | Action |
|
||||
| ------- | ------------------------------------------- |
|
||||
| `Space` | Pause / resume policy execution |
|
||||
| `Tab` | Start / stop human correction |
|
||||
| `Enter` | Push dataset to Hub (corrections-only mode) |
|
||||
| `ESC` | Stop the session |
|
||||
|
||||
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
|
||||
|
||||
| Flag | Description |
|
||||
| ------------------------------------ | ------------------------------------------------------- |
|
||||
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
|
||||
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
|
||||
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
|
||||
|
||||
### Sync (default)
|
||||
|
||||
One policy call per control tick. The main loop blocks until the action is computed.
|
||||
|
||||
Works with all policies. No extra flags needed.
|
||||
|
||||
### Real-Time Chunking (`--inference.type=rtc`)
|
||||
|
||||
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
|
||||
|
||||
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0 \
|
||||
--policy.path=${HF_USER}/pi0_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Pick up the cube" \
|
||||
--duration=60 \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| ------------------------------------------- | -------------------------------------------------------------- |
|
||||
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
|
||||
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
|
||||
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
|
||||
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
|
||||
|
||||
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
|
||||
|
||||
---
|
||||
|
||||
## Common Flags
|
||||
|
||||
| Flag | Description | Default |
|
||||
| --------------------------------- | ----------------------------------------------------------------- | ------- |
|
||||
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
|
||||
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
|
||||
| `--robot.port` | Serial port for the robot | -- |
|
||||
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
|
||||
| `--fps` | Control loop frequency | 30 |
|
||||
| `--duration` | Run time in seconds (0 = infinite) | 0 |
|
||||
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
|
||||
| `--task` | Task description (used when no dataset is provided) | -- |
|
||||
| `--display_data` | Stream telemetry to Rerun visualization | false |
|
||||
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
|
||||
| `--interpolation_multiplier` | Action interpolation factor | 1 |
|
||||
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
|
||||
| `--resume` | Resume a previous recording session | false |
|
||||
| `--play_sounds` | Vocal synthesis for events | true |
|
||||
|
||||
---
|
||||
|
||||
## Programmatic Usage
|
||||
|
||||
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
||||
|
||||
```python
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=my_robot_config,
|
||||
policy=my_policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=30,
|
||||
duration=60,
|
||||
task="my task",
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=my_custom_action_processor, # optional
|
||||
robot_observation_processor=my_custom_obs_processor, # optional
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
```
|
||||
|
||||
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
|
||||
@@ -0,0 +1,147 @@
|
||||
# Language columns and recipes
|
||||
|
||||
Most LeRobot datasets ship with a single `task` string per episode — fine for
|
||||
short, single-instruction skills, but not enough for the longer-horizon,
|
||||
multi-modal robot policies the field is moving toward (high-level planning,
|
||||
memory, interjections, VQA, tool use). To support those policies without
|
||||
forking the dataset format, LeRobot extends `LeRobotDataset` with two optional
|
||||
language columns and a small recipe layer that turns those rows into
|
||||
chat-style training samples on the fly.
|
||||
|
||||
The design splits cleanly into three layers:
|
||||
|
||||
1. **Data in the dataset** — language annotations stored next to frames in
|
||||
`data/chunk-*/file-*.parquet` as two optional columns (`language_persistent`
|
||||
and `language_events`). Datasets without these columns keep their existing
|
||||
behavior.
|
||||
2. **Recipe** — a YAML file that declares which annotation rows to bind and
|
||||
how to lay them out as chat turns (`role`, `content`, optional images,
|
||||
optional tool calls). Recipes are pure config; no Python required to add a
|
||||
new one.
|
||||
3. **Training format** — at sample time, `RenderMessagesStep` resolves the
|
||||
recipe against the per-frame annotations and emits HF-style `messages` plus
|
||||
LeRobot-specific sidecars (`message_streams`, `target_message_indices`)
|
||||
that policy processors consume.
|
||||
|
||||
This page describes each layer in turn.
|
||||
|
||||
## Layer 1 — language columns in the dataset
|
||||
|
||||
The two optional columns live next to frame data in
|
||||
`data/chunk-*/file-*.parquet`:
|
||||
|
||||
- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
|
||||
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.
|
||||
|
||||
Both columns share the same row shape (event rows omit `timestamp` because the
|
||||
frame the row sits on already provides it):
|
||||
|
||||
```text
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float64 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
|
||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||
the matching `observation.images.*` feature key. Rows of every other style —
|
||||
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||
enforce this via `validate_camera_field(style, camera)`.
|
||||
|
||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||
|
||||
### Architecture
|
||||
|
||||
The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style`.
|
||||
2. `lerobot.datasets.language_render` resolves rows and renders messages.
|
||||
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
|
||||
- `active_at(t, style=subtask)`
|
||||
- `nth_prev(style=memory, offset=1)`
|
||||
- `nth_next(style=subtask, offset=1)`
|
||||
|
||||
Event styles only exist on their exact timestamp:
|
||||
|
||||
- `emitted_at(t, style=interjection)`
|
||||
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
|
||||
- `emitted_at(t, role=assistant, tool_name=say)`
|
||||
|
||||
Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.
|
||||
|
||||
### View-dependent resolution
|
||||
|
||||
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||
matches and raise an ambiguity error. Recipes consume each camera through its
|
||||
own binding plus a matching image block, e.g.
|
||||
|
||||
```yaml
|
||||
ask_vqa_top:
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- { type: image, feature: observation.images.top }
|
||||
- { type: text, text: "${vqa_query}" }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${vqa}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
if_present: vqa,
|
||||
}
|
||||
```
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
```python
|
||||
sample["messages"]
|
||||
sample["message_streams"]
|
||||
sample["target_message_indices"]
|
||||
```
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.
|
||||
@@ -61,17 +61,6 @@ lerobot-eval \
|
||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
### Recording
|
||||
|
||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
||||
|
||||
```bash
|
||||
lerobot-record \ # When running inference
|
||||
--policy.path="<user>/smolVLA_finetuned" \
|
||||
... \
|
||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
||||
```
|
||||
|
||||
## Alternative: edit the policy config directly
|
||||
|
||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||
@@ -105,10 +94,10 @@ XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset
|
||||
|
||||
## Quick reference
|
||||
|
||||
| Goal | What to do |
|
||||
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
|
||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||
| Goal | What to do |
|
||||
| --------------------------------------- | --------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Rollout with different keys (inference) | `--rename_map='{"source_key": "policy_key", ...}'`. |
|
||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||
|
||||
+7
-3
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
|
||||
|
||||
### Using RTC with Pi0
|
||||
|
||||
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
|
||||
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
|
||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||
|
||||
```python
|
||||
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar
|
||||
## Testing RTC with a Real Robot
|
||||
|
||||
```bash
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
|
||||
# ... create plots
|
||||
```
|
||||
|
||||
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
|
||||
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
|
||||
|
||||
## References
|
||||
|
||||
|
||||
+29
-28
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
@@ -465,15 +465,15 @@ This script:
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.head_mode=sparse \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -488,12 +488,13 @@ The training script automatically:
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| Argument | Description | Default |
|
||||
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ----------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--rabc_kappa=0.03
|
||||
--sample_weighting.kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--rabc_kappa=0.05
|
||||
--sample_weighting.kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--rabc_kappa=0.07
|
||||
--sample_weighting.kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
@@ -550,8 +551,8 @@ accelerate launch \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -576,7 +577,7 @@ accelerate launch \
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
# Tools
|
||||
|
||||
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||
emit structured invocations like `say(text="OK, starting now")` that the
|
||||
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||
|
||||
This page covers:
|
||||
|
||||
1. Where the tool catalog lives.
|
||||
2. How the annotation pipeline produces tool-call atoms.
|
||||
3. How to add your own tool.
|
||||
|
||||
## Where tools are declared
|
||||
|
||||
Two layers.
|
||||
|
||||
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||
`meta/info.json["tools"]` on each dataset. Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"features": { "...": "..." },
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak."
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Read it via the dataset metadata accessor:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||
```
|
||||
|
||||
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||
single-entry list with the canonical `say` schema. So unannotated
|
||||
datasets and chat-template consumers keep working without any
|
||||
configuration:
|
||||
|
||||
```python
|
||||
prompt_str = tokenizer.apply_chat_template(
|
||||
sample["messages"],
|
||||
tools=meta.tools, # works either way
|
||||
add_generation_prompt=False,
|
||||
tokenize=False,
|
||||
)
|
||||
```
|
||||
|
||||
**The implementations** — runnable Python — will live under
|
||||
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||
the canonical `say` implementation (wrapping Kyutai's pocket-tts) are
|
||||
not part of the catalog layer described here; today this layer ships
|
||||
only the schema storage and the `DEFAULT_TOOLS` fallback constant.
|
||||
|
||||
## Per-row tool _invocations_
|
||||
|
||||
The catalog above describes _what can be called_. The actual _call_ — the
|
||||
function name plus the argument values — is stored per-row, on the
|
||||
assistant atoms in `language_events`:
|
||||
|
||||
```python
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"style": null,
|
||||
"timestamp": 12.4,
|
||||
"camera": null,
|
||||
"tool_calls": [
|
||||
{ "type": "function",
|
||||
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||
|
||||
```yaml
|
||||
user_interjection_response:
|
||||
bindings:
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- {
|
||||
role: assistant,
|
||||
content: "${current_plan}",
|
||||
stream: high_level,
|
||||
target: true,
|
||||
tool_calls_from: speech,
|
||||
}
|
||||
```
|
||||
|
||||
The model's training target is one assistant turn that carries both the
|
||||
plan text _and_ the `say` tool call. At inference, the runtime parses
|
||||
the generated text back into structured `tool_calls` and dispatches to
|
||||
the matching implementation.
|
||||
|
||||
## How to add your own tool
|
||||
|
||||
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||
> `get_tools(meta)`) which is not part of the catalog layer shipped
|
||||
> today — those modules don't yet exist in the tree. Step 1 alone is
|
||||
> enough to make the tool visible to the chat template via
|
||||
> `meta.tools` so the model can learn to _generate_ the call;
|
||||
> executing the call at inference requires the runtime layer.
|
||||
|
||||
Three steps. Concrete example: a `record_observation` tool the policy
|
||||
can call to capture an extra observation outside the regular control
|
||||
loop.
|
||||
|
||||
### Step 1 — declare the schema
|
||||
|
||||
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||
directly on disk _before_ running the annotation pipeline (it'll be
|
||||
preserved) or hand it to `lerobot-annotate` via a config flag.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a high-resolution still image for the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Short label for the saved image."
|
||||
}
|
||||
},
|
||||
"required": ["label"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The schema follows OpenAI's function-calling convention exactly, so the
|
||||
chat template can render it natively.
|
||||
|
||||
### Step 2 — implement the call
|
||||
|
||||
Create `src/lerobot/tools/record_observation.py`:
|
||||
|
||||
```python
|
||||
from .base import Tool
|
||||
from typing import Any
|
||||
|
||||
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||
|
||||
|
||||
class RecordObservationTool:
|
||||
name = "record_observation"
|
||||
schema = RECORD_OBSERVATION_SCHEMA
|
||||
|
||||
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def call(self, arguments: dict) -> str:
|
||||
label = arguments["label"]
|
||||
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||
return f"saved {label}.png"
|
||||
```
|
||||
|
||||
One file per tool keeps dependencies isolated — `record_observation`
|
||||
might pull `pillow`, while `say` pulls `pocket-tts`. Users installing
|
||||
only the tools they need avoid heavy transitive deps.
|
||||
|
||||
### Step 3 — register it
|
||||
|
||||
Add to `src/lerobot/tools/registry.py`:
|
||||
|
||||
```python
|
||||
from .record_observation import RecordObservationTool
|
||||
|
||||
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||
```
|
||||
|
||||
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||
`meta.tools`, instantiates the matching registered class, and returns
|
||||
a name → instance dict the dispatcher can route into.
|
||||
|
||||
If you want to use a tool _without_ writing an implementation (e.g. for
|
||||
training-time chat-template formatting only), step 1 alone is enough —
|
||||
the model still learns to _generate_ the call. Steps 2 and 3 are only
|
||||
needed to actually _execute_ it at inference.
|
||||
@@ -274,7 +274,8 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
Once trained, we recommend deploying policies using inference-time RTC:
|
||||
|
||||
```bash
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=your-username/your-repo-id \
|
||||
--policy.device=cuda \
|
||||
--robot.type=unitree_g1 \
|
||||
@@ -284,7 +285,7 @@ python examples/rtc/eval_with_real_robot.py \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--rtc.enabled=true
|
||||
--inference.type=rtc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
|
||||
|
||||
Spawns one ``h200x2`` job that:
|
||||
|
||||
1. installs this branch of ``lerobot`` plus the annotation extras,
|
||||
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
|
||||
3. runs Module 1/2/3 across the dataset (per-camera VQA via PR 3471),
|
||||
4. uploads the annotated dataset to ``--push_to_hub``.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotation/run_hf_job.py
|
||||
|
||||
Adjust ``CMD`` below to point at your own dataset / target hub repo.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
# Mirror lerobot's [annotations] runtime deps. ``openai`` is required
|
||||
# because ``VlmConfig.backend`` defaults to ``"openai"`` (which talks
|
||||
# to a vllm/transformers/ktransformers OpenAI-compatible server).
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include "
|
||||
"toml typing-inspect openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=imstevenpmwork/super_poulain_draft "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
|
||||
"--vlm.parallel_servers=2 "
|
||||
"--vlm.num_gpus=2 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=256 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
"--executor.episode_parallelism=32 "
|
||||
"--vlm.chat_template_kwargs='{enable_thinking: false}' "
|
||||
"--vlm.camera_key=observation.images.wrist "
|
||||
"--module_1.frames_per_second=1.0 "
|
||||
"--module_1.use_video_url=true "
|
||||
"--module_1.use_video_url_fps=1.0 "
|
||||
"--module_3.K=1 --module_3.vqa_emission_hz=0.2 "
|
||||
"--push_to_hub=pepijn223/super_poulain_qwen36moe-3"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200x2",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
from lerobot.rewards.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,226 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
vcodec: str = "auto"
|
||||
streaming_encoding: bool = True
|
||||
encoder_queue_maxsize: int = 30
|
||||
encoder_threads: int | None = None
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
"""Check if teleoperator has motor control capabilities."""
|
||||
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
|
||||
|
||||
|
||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
"""Disable teleop torque if supported."""
|
||||
if hasattr(teleop, "disable_torque"):
|
||||
teleop.disable_torque()
|
||||
|
||||
|
||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
"""Enable teleop torque if supported."""
|
||||
if hasattr(teleop, "enable_torque"):
|
||||
teleop.enable_torque()
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
||||
return
|
||||
|
||||
teleop_enable_torque(teleop)
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.write_goal_positions(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize keyboard listener with HIL controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"resume_policy": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logger.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
||||
logger.info("[HIL] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("[HIL] Taking control...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "p":
|
||||
if events["policy_paused"] or events["correction_active"]:
|
||||
logger.info("[HIL] Resuming policy...")
|
||||
events["resume_policy"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
logger.info("[HIL] End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
logger.info("[HIL] Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("[HIL] ESC - Stop recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
logger.info(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
return listener, events
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, obs_proc
|
||||
|
||||
|
||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
logger.info("[HIL] RESET")
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
logger.info("Press any key to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop_disable_torque(teleop)
|
||||
logger.info("Teleop enabled - press any key to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
events["resume_policy"] = False
|
||||
|
||||
|
||||
def print_controls(rtc: bool = False):
|
||||
"""Print control instructions."""
|
||||
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
|
||||
logger.info(
|
||||
"%s\n Controls:\n"
|
||||
" SPACE - Pause policy\n"
|
||||
" c - Take control\n"
|
||||
" p - Resume policy after pause/correction\n"
|
||||
" → - End episode\n"
|
||||
" ESC - Stop and push to hub",
|
||||
mode,
|
||||
)
|
||||
+62
-31
@@ -14,17 +14,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
@@ -83,43 +90,67 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot
|
||||
robot_action_to_send = robot_action_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
|
||||
@@ -45,9 +45,6 @@ def main():
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
@@ -77,6 +74,10 @@ def main():
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = (
|
||||
make_default_processors()
|
||||
)
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
@@ -87,14 +88,14 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -106,13 +107,13 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained policy on LeKiwi without recording (base rollout).
|
||||
|
||||
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||
control tick). For a CLI entry point with the same capabilities plus
|
||||
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
|
||||
"""
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
|
||||
# by ``build_rollout_context`` to reload the full model.
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
# Assemble the rollout config: base strategy (no recording) + sync inference.
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
# Build the context (connects robot, loads policy, wires the inference strategy).
|
||||
# No custom processors here — LeKiwi runs on raw joint features.
|
||||
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,13 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
@@ -143,43 +151,67 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot (EE -> joints via IK)
|
||||
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
@@ -190,7 +222,6 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -65,14 +65,15 @@ def main():
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to EE action
|
||||
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
@@ -94,7 +95,7 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
# Build pipeline to convert EE action to joints action (IK).
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
@@ -107,7 +108,7 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
# Build pipeline to convert joint observation to EE observation (FK).
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -118,13 +119,12 @@ def main():
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||
# matches exactly what the pipelines produce at runtime.
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
@@ -163,14 +163,14 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -182,13 +182,13 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
|
||||
|
||||
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
|
||||
with phone teleoperation in EE space, so at deployment we only need the
|
||||
joint↔EE conversion on the robot side; the phone is not used.
|
||||
|
||||
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
|
||||
(inline policy call). For recording during rollout, switch to Sentry,
|
||||
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
|
||||
"""
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Peek at motor names once to build the kinematic solver.
|
||||
temp_robot = SO100Follower(robot_config)
|
||||
motor_names = list(temp_robot.bus.motors.keys())
|
||||
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,673 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
||||
|
||||
This script demonstrates:
|
||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
||||
2. Consuming actions from the policy while the robot executes
|
||||
3. Periodically requesting new action chunks in the background using threads
|
||||
4. Managing action buffers and timing for real-time operation
|
||||
|
||||
For simulation environments, see eval_with_simulation.py
|
||||
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot-data-collection/folding_final \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
||||
--robot.left_arm_config.port=can0 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.left_arm_config.can_interface=socketcan \
|
||||
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
||||
--robot.left_arm_config.max_relative_target=8.0 \
|
||||
--robot.right_arm_config.port=can1 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.right_arm_config.can_interface=socketcan \
|
||||
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
||||
--robot.right_arm_config.max_relative_target=8.0 \
|
||||
--task="Fold the T-shirt properly" \
|
||||
--fps=30 \
|
||||
--duration=2000 \
|
||||
--interpolation_multiplier=3 \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
--device=cuda
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
to_relative_actions,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: Tensor,
|
||||
current_state: Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> Tensor:
|
||||
"""Convert absolute leftovers into model-space for relative-action RTC policies.
|
||||
|
||||
When a policy uses relative actions, the RTC prefix (leftover actions from
|
||||
the previous chunk) is stored in absolute space. Before feeding it back to
|
||||
the policy we need to re-express it relative to the *current* robot state
|
||||
and then re-normalize.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
# Only keep .pos joints + camera streams if the policy was trained on positions,
|
||||
# not the full pos/vel/torque state the robot exposes.
|
||||
observation_features_hw = {
|
||||
key: value
|
||||
for key, value in robot.observation_features().items()
|
||||
if key.endswith(".pos") or isinstance(value, tuple)
|
||||
}
|
||||
|
||||
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
# The stats are embedded in the processor .safetensors files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
if relative_step is not None:
|
||||
if relative_step.action_names is None:
|
||||
cfg_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
if cfg_names:
|
||||
relative_step.action_names = list(cfg_names)
|
||||
else:
|
||||
relative_step.action_names = [
|
||||
k for k in robot.robot.action_features if k.endswith(".pos")
|
||||
]
|
||||
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Re-anchor leftover actions for relative-action policies.
|
||||
# We need the *postprocessed* (absolute) leftover, not the original
|
||||
# (normalized/relative) one that get_left_over() returns.
|
||||
if (
|
||||
prev_actions is not None
|
||||
and relative_step is not None
|
||||
and OBS_STATE in obs_with_policy_features
|
||||
):
|
||||
with action_queue.lock:
|
||||
if action_queue.queue is not None:
|
||||
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
|
||||
else:
|
||||
prev_actions_abs = None
|
||||
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_abs,
|
||||
current_state=obs_with_policy_features[OBS_STATE],
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Store original actions (before postprocessing) for RTC
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
|
||||
|
||||
action_count = 0
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with draccus configuration."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_pretrained_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processort, as by default if RTC disabled in the config
|
||||
# The processor won't be created
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
@@ -14,13 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
@@ -143,43 +151,67 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot (EE -> joints via IK)
|
||||
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
@@ -190,7 +222,6 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -62,21 +62,20 @@ def main():
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
# Build pipeline to convert follower joints to EE observation.
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -87,7 +86,7 @@ def main():
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
# Build pipeline to convert leader joints to EE action.
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -98,9 +97,9 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
# Build pipeline to convert EE action to follower joints (with safety bounds).
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
steps=[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
@@ -115,13 +114,12 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||
# matches exactly what the pipelines produce at runtime.
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
@@ -144,7 +142,7 @@ def main():
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
init_rerun(session_name="recording_so100_ee")
|
||||
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
@@ -160,14 +158,14 @@ def main():
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -179,13 +177,13 @@ def main():
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained EE-space policy on SO100 without recording (base rollout).
|
||||
|
||||
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||
control tick). The custom observation/action processors convert between
|
||||
joint space (robot hardware) and end-effector space (policy I/O) via
|
||||
forward/inverse kinematics.
|
||||
"""
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Kinematic solver: we need the motor-name list, so peek at the robot once.
|
||||
# (The rollout engine owns the connected instance; we only use this for introspection.)
|
||||
temp_robot = SO100Follower(robot_config)
|
||||
motor_names = list(temp_robot.bus.motors.keys())
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
# Joint-space observation → EE-space observation (consumed by the policy).
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# EE-space action (produced by the policy) → joint-space action (sent to robot).
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Policy config (full model is loaded inside build_rollout_context).
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
|
||||
# otherwise skip the joint↔EE conversion and the policy would receive the
|
||||
# wrong observation/action space.
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
|
||||
|
||||
|
||||
def main():
|
||||
@@ -22,10 +22,10 @@ def main():
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
# Make reward model, preprocessor, and optimizer
|
||||
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
|
||||
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
|
||||
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
@@ -42,7 +42,7 @@ def main():
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
loss, output_dict = reward_model.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -58,8 +58,8 @@ def main():
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
# You can now save the trained reward model.
|
||||
reward_model.push_to_hub(classifier_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+18
-3
@@ -95,7 +95,7 @@ dependencies = [
|
||||
|
||||
# ── Feature-scoped extras ──────────────────────────────────
|
||||
dataset = [
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.7.0,<5.0.0",
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
@@ -128,7 +128,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -194,12 +194,25 @@ groot = [
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
|
||||
|
||||
# Annotation pipeline (lerobot-annotate). vllm is the preferred backend
|
||||
# on Linux, with a transformers fallback elsewhere; openai is the default
|
||||
# backend and talks to any OpenAI-compatible server (``vllm serve`` /
|
||||
# ``transformers serve`` / hosted endpoints). Distributed execution is
|
||||
# delegated to Hugging Face Jobs (see examples/annotation/run_hf_job.py).
|
||||
annotations = [
|
||||
"lerobot[dataset]",
|
||||
"lerobot[transformers-dep]",
|
||||
"openai>=1.40,<2.0",
|
||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
@@ -289,10 +302,12 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Steerable annotation pipeline producing ``language_persistent`` and
|
||||
``language_events`` columns for LeRobot datasets.
|
||||
|
||||
The pipeline is decomposed into three independently runnable modules whose
|
||||
outputs are staged per-episode before a final parquet rewrite:
|
||||
|
||||
- :mod:`.modules.plan_subtasks_memory` (Module 1) — persistent styles
|
||||
- :mod:`.modules.interjections_and_speech` (Module 2) — event styles + speech
|
||||
- :mod:`.modules.general_vqa` (Module 3) — event-style VQA pairs
|
||||
"""
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .validator import StagingValidator, ValidationReport
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
__all__ = [
|
||||
"AnnotationPipelineConfig",
|
||||
"LanguageColumnsWriter",
|
||||
"StagingValidator",
|
||||
"ValidationReport",
|
||||
]
|
||||
@@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module1Config:
|
||||
"""Module 1: plan + subtasks + memory + task augmentation.
|
||||
|
||||
Module 1 attaches the whole episode as one Qwen-VL video block;
|
||||
``max_video_frames`` only caps the frames packed in (a model-capacity
|
||||
bound, not an annotation-logic knob).
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Number of ``task_aug`` rephrasings emitted at ``t=0``. The renderer's
|
||||
# ``${task}`` binding rotates among them per ``sample_idx``. ``0`` disables.
|
||||
n_task_rephrasings: int = 10
|
||||
|
||||
# When to derive the task from the video instead of using
|
||||
# ``record.episode_task``: ``off``, ``if_short`` (short / placeholder /
|
||||
# missing canonical task), or ``always``. The derived task replaces the
|
||||
# canonical one for every Module-1 prompt; ``meta/tasks.parquet`` is
|
||||
# never modified.
|
||||
derive_task_from_video: str = "if_short"
|
||||
derive_task_min_words: int = 3
|
||||
|
||||
# Frame sampling for the subtask-decomposition prompt.
|
||||
frames_per_second: float = 1.0
|
||||
max_video_frames: int = 128
|
||||
|
||||
min_subtask_seconds: float = 1.5
|
||||
plan_max_steps: int = 8
|
||||
|
||||
# When True (and backend supports it, e.g. ``openai``), Module 1 sends a
|
||||
# ``video_url`` block pointing at a per-episode mp4 subclip and lets the
|
||||
# server sample frames at ``use_video_url_fps``.
|
||||
use_video_url: bool = False
|
||||
use_video_url_fps: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module2Config:
|
||||
"""Module 2: interjections + paired speech."""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
# Each interjection emits a paired ``(interjection, speech)`` event row
|
||||
# and triggers a ``plan`` refresh at the same timestamp via Module 1.
|
||||
max_interjections_per_episode: int = 3
|
||||
interjection_min_t: float = 2.0
|
||||
|
||||
# Visual context attached to the interjection prompt: a short window
|
||||
# of frames centered on the chosen timestamp so the VLM sees the
|
||||
# ongoing motion rather than a single frozen frame.
|
||||
interjection_window_seconds: float = 2.0
|
||||
interjection_window_frames: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module3Config:
|
||||
"""Module 3: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 3
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
# One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests).
|
||||
# ``openai`` talks to a local OpenAI-compatible server; the CLI
|
||||
# auto-spawns one when ``auto_serve=True``.
|
||||
backend: str = "openai"
|
||||
model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
# OpenAI-compatible server endpoint; ``EMPTY`` works for local servers.
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
api_key: str = "EMPTY"
|
||||
|
||||
# When True with ``backend=openai``, the CLI probes ``api_base`` and
|
||||
# spawns a server if none answers (default: ``transformers serve``).
|
||||
# Set to False to fail fast when pointing at a remote endpoint.
|
||||
auto_serve: bool = True
|
||||
serve_port: int = 8000
|
||||
# Override the auto-serve command. ``{port}`` is substituted per replica
|
||||
# when ``parallel_servers > 1``.
|
||||
serve_command: str | None = None
|
||||
|
||||
# Run multiple independent inference servers for round-robin client
|
||||
# routing (each pinned to a GPU via ``CUDA_VISIBLE_DEVICES`` and bound
|
||||
# to ``serve_port + i``). ``num_gpus=0`` means one GPU per replica.
|
||||
parallel_servers: int = 1
|
||||
num_gpus: int = 0
|
||||
client_concurrency: int = 16
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
json_mode: bool = True
|
||||
batch_size: int = 4
|
||||
tensor_parallel_size: int = 1
|
||||
|
||||
# Fraction of GPU memory vllm allocates for weights + KV cache.
|
||||
gpu_memory_utilization: float = 0.9
|
||||
# Cap context length (None = model default). On 80 GB H100 a 30B BF16
|
||||
# model often needs <= 8192 to leave KV-cache headroom.
|
||||
max_model_len: int | None = None
|
||||
trust_remote_code: bool = False
|
||||
|
||||
# Override the camera stream used for keyframe attachment. None picks
|
||||
# the first ``observation.images.*`` key the dataset declares.
|
||||
camera_key: str | None = None
|
||||
# Forwarded as ``extra_body.chat_template_kwargs`` on every chat call;
|
||||
# use to pass model-specific flags such as ``{"enable_thinking": false}``.
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor settings.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotation/run_hf_job.py``); this config only controls
|
||||
intra-process episode concurrency.
|
||||
"""
|
||||
|
||||
# Episodes processed concurrently within each module phase. Each
|
||||
# in-flight episode dispatches 3-5 dependent VLM calls, so this is the
|
||||
# main knob for saturating ``parallel_servers`` and ``client_concurrency``.
|
||||
episode_parallelism: int = 16
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate``.
|
||||
|
||||
The writer rewrites ``data/chunk-*/file-*.parquet`` in place. Multiple
|
||||
revisions of the same dataset live in separate copies.
|
||||
"""
|
||||
|
||||
repo_id: str | None = None
|
||||
root: Path | None = None
|
||||
|
||||
# Defaults to ``<root>/.annotate_staging/`` when unset.
|
||||
staging_dir: Path | None = None
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
module_1: Module1Config = field(default_factory=Module1Config)
|
||||
module_2: Module2Config = field(default_factory=Module2Config)
|
||||
module_3: Module3Config = field(default_factory=Module3Config)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
# Upload the annotated dataset to the Hugging Face Hub when set.
|
||||
push_to_hub: str | None = None
|
||||
push_private: bool = False
|
||||
push_commit_message: str | None = None
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
|
||||
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""In-process executor that runs the annotation phases.
|
||||
|
||||
The executor plans **six phases** in the dependency order from the plan:
|
||||
|
||||
phase 1: Module 1 (plan + subtasks + memory)
|
||||
phase 2: Module 2 (interjections + speech)
|
||||
phase 3: Module 1 plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: Module 3 (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why Module 1 must be re-entered after Module 2 — to refresh
|
||||
``plan`` rows at interjection timestamps.
|
||||
|
||||
Distributed execution is provided by Hugging Face Jobs (see
|
||||
``examples/annotation/run_hf_job.py``); the runner inside the job
|
||||
invokes ``lerobot-annotate`` which uses this in-process executor.
|
||||
Episode-level concurrency is controlled by
|
||||
``ExecutorConfig.episode_parallelism``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all four phases over a dataset root in-process.
|
||||
|
||||
Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism``
|
||||
(a thread pool); cluster-level concurrency comes from running this
|
||||
executor inside a Hugging Face Job. Tests construct the executor
|
||||
directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
module_1: Any # PlanSubtasksMemoryModule
|
||||
module_2: Any # InterjectionsAndSpeechModule
|
||||
module_3: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
print(f"[annotate] {n} episodes total", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: Module 1 (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
|
||||
# Phase 2: Module 2 (interjections + speech). Module 2 reads
|
||||
# Module 1's subtask rows from the same staging tree to ground
|
||||
# the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
|
||||
# Phase 3: Module 1 plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: Module 3 (VQA)
|
||||
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_annotation_metadata_in_info(root: Path) -> None:
|
||||
"""Write language features and canonical tools to ``meta/info.json``.
|
||||
|
||||
``LanguageColumnsWriter`` adds ``language_persistent`` and
|
||||
``language_events`` to parquet shards. The metadata must advertise
|
||||
those columns too, otherwise non-streaming ``LeRobotDataset`` loads
|
||||
cast against the old schema and fail on the extra parquet columns.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = json.loads(info_path.read_text())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
changed = False
|
||||
|
||||
features = info.get("features")
|
||||
if not isinstance(features, dict):
|
||||
features = {}
|
||||
merged_features = {**features, **language_feature_info()}
|
||||
if merged_features != features:
|
||||
info["features"] = merged_features
|
||||
changed = True
|
||||
|
||||
existing = info.get("tools")
|
||||
if not isinstance(existing, list):
|
||||
existing = []
|
||||
names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)}
|
||||
merged = list(existing)
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
merged.append(SAY_TOOL_SCHEMA)
|
||||
if merged != existing:
|
||||
info["tools"] = merged
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
# Atomic replace — info.json is load-bearing for dataset
|
||||
# metadata, so a crash mid-write would brick the dataset.
|
||||
tmp_info = info_path.with_suffix(info_path.suffix + ".tmp")
|
||||
tmp_info.write_text(json.dumps(info, indent=2))
|
||||
tmp_info.replace(info_path)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
f"tools={[t['function']['name'] for t in merged]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
import time as _time # noqa: PLC0415
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
|
||||
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = _time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = _time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, _time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = _time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
|
||||
|
||||
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
|
||||
therefore calls back into Module 1 with the interjection timestamps so
|
||||
Module 1's existing prompt path is reused.
|
||||
"""
|
||||
if not self.module_1.enabled or not self.module_2.enabled:
|
||||
return PhaseResult(
|
||||
name="module_1_plan_update", episodes_processed=0, episodes_skipped=len(records)
|
||||
)
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row for row in staging.read("module_2") if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.module_1.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="module_1_plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
@@ -0,0 +1,394 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(Module 1, Module 2) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` PIL images covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. The returned list is
|
||||
intended to be passed as one ``{"type":"video", "video":<list>}``
|
||||
block to a Qwen-VL-compatible model that pools temporally itself.
|
||||
Empty list if no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for Module 1 (subtask
|
||||
decomposition) and Module 2 (interjection scenarios) — those prompts care
|
||||
about *what is happening*, not which angle. Module 3 (VQA) instead
|
||||
iterates over every camera in :attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
# Pipeline runs Module 1/2/3 phases under a ThreadPoolExecutor (see
|
||||
# ``ExecutorConfig.episode_parallelism``); guard the dict cache and the
|
||||
# one-shot warn flag against concurrent updates from worker threads.
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# ``camera_keys`` covers both image- and video-stored cameras
|
||||
# (``video_keys`` is video-only). Some datasets declare cameras with
|
||||
# ``dtype=image``, which would otherwise look empty here and silently
|
||||
# disable Module 3 even though the videos are there.
|
||||
keys = list(getattr(self._meta, "camera_keys", None) or self._meta.video_keys or [])
|
||||
# Last-resort fallback: if metadata didn't surface anything but the
|
||||
# caller explicitly named a camera (``--vlm.camera_key=...``), trust
|
||||
# them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
with self._lock:
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# decoder may return fewer frames than requested when some
|
||||
# timestamps fall outside the video; pair what we have and
|
||||
# leave the rest as None to be filtered below.
|
||||
with self._lock:
|
||||
for i, img in zip(miss_indices, decoded, strict=False):
|
||||
out[i] = img
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = img
|
||||
# filter out any None left over from decode failures
|
||||
return [img for img in out if img is not None]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` images uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally.
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips
|
||||
re-extract when the cached clip already exists. Re-encodes to
|
||||
H.264 (libx264) so the resulting mp4 is decodable by every
|
||||
downstream video processor — stream-copy would inherit the
|
||||
source codec (often AV1 in modern LeRobot datasets), which
|
||||
vllm's libav build cannot decode.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if self.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = self._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"])
|
||||
src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, timeout=300)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]:
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
try:
|
||||
return _decode_pyav_direct(video_path, shifted, self.tolerance_s)
|
||||
except Exception as exc:
|
||||
# Log loudly the first time decoding fails so silent
|
||||
# Module-3-no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
with self._lock:
|
||||
already_warned = getattr(self, "_warned_decode_fail", False)
|
||||
if not already_warned:
|
||||
self._warned_decode_fail = True
|
||||
if not already_warned:
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _decode_pyav_direct(video_path: Any, timestamps: list[float], tolerance_s: float) -> list[Any]:
|
||||
"""Decode the requested timestamps from ``video_path`` using PyAV directly.
|
||||
|
||||
Bypasses ``lerobot.datasets.video_utils.decode_video_frames`` entirely
|
||||
because its "pyav" path actually goes through
|
||||
``decode_video_frames_torchvision`` → ``torchvision.io.VideoReader``,
|
||||
which was removed in torchvision >= 0.22 (the vllm/vllm-openai:latest
|
||||
container ships with torchvision 0.25). The annotation pipeline only
|
||||
needs a handful of PIL images per (episode, ts), so we can decode them
|
||||
with PyAV without any torch dependency at all.
|
||||
|
||||
Returns one ``PIL.Image`` per requested timestamp, in the same order.
|
||||
Any timestamp the decoder couldn't reach is silently dropped (mirrors
|
||||
the previous behaviour); callers filter ``None``/missing entries.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
if not timestamps:
|
||||
return []
|
||||
|
||||
targets = sorted(set(timestamps))
|
||||
seek_to = max(0.0, min(targets) - max(0.5, tolerance_s))
|
||||
|
||||
container = av.open(str(video_path))
|
||||
try:
|
||||
stream = container.streams.video[0]
|
||||
# PyAV needs the seek target in stream timebase ticks.
|
||||
seek_pts = 0 if stream.time_base is None else int(seek_to / float(stream.time_base))
|
||||
try:
|
||||
container.seek(seek_pts, any_frame=False, backward=True, stream=stream)
|
||||
except av.AVError:
|
||||
# Some streams reject the explicit seek; fall back to decoding from start.
|
||||
container.seek(0)
|
||||
|
||||
results: dict[float, Any] = {}
|
||||
target_iter = iter(targets)
|
||||
next_target = next(target_iter, None)
|
||||
for frame in container.decode(stream):
|
||||
if next_target is None:
|
||||
break
|
||||
ts = float(frame.pts * frame.time_base) if frame.pts is not None else None
|
||||
if ts is None:
|
||||
continue
|
||||
# Walk past targets we've already overshot — we keep the closest
|
||||
# frame within tolerance.
|
||||
while next_target is not None and ts >= next_target - tolerance_s:
|
||||
if abs(ts - next_target) <= tolerance_s or ts >= next_target:
|
||||
img = frame.to_image() # PIL.Image.Image (RGB)
|
||||
results.setdefault(next_target, img)
|
||||
next_target = next(target_iter, None)
|
||||
else:
|
||||
break
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
return [results[ts] for ts in timestamps if ts in results]
|
||||
|
||||
|
||||
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert PIL images to Qwen-VL-compatible content blocks."""
|
||||
return [{"type": "image", "image": img} for img in images]
|
||||
|
||||
|
||||
def to_video_block(images: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of PIL images as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not images:
|
||||
return []
|
||||
return [{"type": "video", "video": list(images)}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 3: general VQA at a timed cadence.
|
||||
|
||||
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
|
||||
emission. For datasets with multiple cameras, every emission tick produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's Module 3 table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module3Config
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module3Config
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not getattr(self, "_warned_no_camera", False):
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"Module 3 (VQA) found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("module_3", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras Module 3 should iterate per emission tick.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
"""
|
||||
return list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("module_3_vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(
|
||||
record, [frame_timestamp], camera_key=camera_key
|
||||
)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
def _generate_one(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
messages = self._build_messages(record, question_type, frame_timestamp, camera_key)
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
return self._postprocess(result)
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 2: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
Module 1's :meth:`run_plan_updates` reuses Module 2's interjection
|
||||
timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module2Config
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module2Config
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull Module 1's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. Module 1 ran first.
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("module_1"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("module_2", rows)
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("module_2_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what qwen36moe-10/11 surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("module_2_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(0.0, tgt))
|
||||
t = snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
|
||||
@@ -0,0 +1,389 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 1: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module1Config
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
Module 2 completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module1Config
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Resolve the task that drives every other Module-1 prompt. May be
|
||||
# the canonical ``record.episode_task`` (default), or a fresh
|
||||
# description derived from the video when the canonical task is
|
||||
# empty / placeholder / forced-off (see Module1Config.derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# ``task_aug`` rows at t=0 (role=user), one per rephrasing — the
|
||||
# PR 1 renderer rotates ``${task}`` deterministically through them
|
||||
# so the policy sees diverse phrasings during training.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
if self.config.n_task_rephrasings > 0 and effective_task:
|
||||
rephrasings = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
|
||||
# Always include the effective task itself as the first variant
|
||||
# so the rotation is guaranteed to cover the source-of-truth
|
||||
# phrasing, not just synthetic alternatives.
|
||||
seen: set[str] = set()
|
||||
ordered = [effective_task, *rephrasings]
|
||||
for phrasing in ordered:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": key,
|
||||
"style": "task_aug",
|
||||
"timestamp": t0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# plan row at t=0
|
||||
plan_text = self._generate_plan(record, subtask_spans, task=effective_task)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(t0),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start
|
||||
prior_memory = ""
|
||||
for i, span in enumerate(subtask_spans[1:], start=1):
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(record, prior_memory, completed, remaining, task=effective_task)
|
||||
if mem_text:
|
||||
ts = snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("module_1", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives Module 1 for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
return task.lower() in self._PLACEHOLDER_TASKS
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VLM call helpers (factored out: every Module-1 prompt below follows
|
||||
# the same "build messages → single VLM call → pull a named field"
|
||||
# shape, only differing in field name + post-processing).
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
|
||||
"""Run a single VLM call and return ``result[field]`` or ``None``.
|
||||
|
||||
Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
|
||||
dance every prompt-call site needs.
|
||||
"""
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict):
|
||||
return result.get(field)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _text_message(text: str) -> list[dict[str, Any]]:
|
||||
"""One-shot text-only user message wrapped for ``generate_json``."""
|
||||
return [{"role": "user", "content": [{"type": "text", "text": text}]}]
|
||||
|
||||
def _video_message(self, record: EpisodeRecord, prompt: str) -> list[dict[str, Any]]:
|
||||
"""User message combining the episode video block with ``prompt``."""
|
||||
content = [*self._episode_video_block(record), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
text = self._vlm_field(self._video_message(record, load_prompt("module_1_video_task")), "task")
|
||||
return text.strip() if isinstance(text, str) and text.strip() else None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("module_1_task_rephrasings").format(base_task=base_task, n=n)
|
||||
raw = self._vlm_field(self._text_message(prompt), "rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
|
||||
return [s for s in out if s][:n]
|
||||
|
||||
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
|
||||
"""Same video block ``_generate_subtasks`` builds — extracted helper."""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = self.frame_provider.episode_clip_path(record, cache_dir)
|
||||
return (
|
||||
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
if clip is not None
|
||||
else []
|
||||
)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
target_count = max(1, int(round(episode_duration * self.config.frames_per_second)))
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
return to_video_block(video_frames)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections — subtask generation
|
||||
runs ~1 Hz at inference, but plan re-emission is event-driven.
|
||||
Now also forwards the interjection's own text into the prompt so
|
||||
the refreshed plan can actually reflect the user's correction
|
||||
(the previous version told the model "an interjection happened"
|
||||
without telling it what the user said).
|
||||
"""
|
||||
existing = staging.read("module_1")
|
||||
# Pass the episode's last frame timestamp so the final subtask
|
||||
# span is closed (otherwise its ``end`` equals its ``start``,
|
||||
# zero duration, and the "current subtask at refresh_t" lookup
|
||||
# in ``_generate_plan`` misses any refresh that lands inside it).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts, strict=True):
|
||||
t = snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("module_1", new_rows)
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
prompt = load_prompt("module_1_subtasks").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
)
|
||||
spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
|
||||
if not spans:
|
||||
return []
|
||||
# clamp to [t0, t_last] and sort
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(t0, min(start, t_last))
|
||||
end = max(t0, min(end, t_last))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
return cleaned
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None,
|
||||
task: str | None = None,
|
||||
) -> str | None:
|
||||
if not subtask_spans:
|
||||
return None
|
||||
subtasks_text = "\n".join(f"- {s['text']}" for s in subtask_spans)
|
||||
prompt = load_prompt("module_1_plan").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
subtasks_text=subtasks_text,
|
||||
plan_max_steps=self.config.plan_max_steps,
|
||||
)
|
||||
if refresh_t is not None:
|
||||
# ``current_subtask`` is the span the refresh time falls into,
|
||||
# so the model knows where in the demonstration the planner is
|
||||
# standing when it re-emits.
|
||||
current_subtask = ""
|
||||
for span in subtask_spans:
|
||||
if float(span["start"]) <= refresh_t and (
|
||||
"end" not in span or float(span["end"]) > refresh_t
|
||||
):
|
||||
current_subtask = span.get("text", "")
|
||||
break
|
||||
if interjection:
|
||||
prompt += (
|
||||
f"\n\n(Plan refresh at t={refresh_t:.2f}s after a user "
|
||||
f"interjection: {interjection!r}. Current subtask just "
|
||||
f"before the interjection: {current_subtask!r}. Update "
|
||||
f"the plan so it reflects the interjection — drop or "
|
||||
f"reorder steps as needed; do not just restate.)\n"
|
||||
)
|
||||
else:
|
||||
# Refresh without an interjection text: still tell the model
|
||||
# where in the episode the plan stands so the re-emission
|
||||
# is grounded. Should be rare — plan refreshes are
|
||||
# interjection-driven by design.
|
||||
prompt += f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current subtask: {current_subtask!r}.)\n"
|
||||
plan = self._vlm_field(self._text_message(prompt), "plan")
|
||||
return plan.strip() if isinstance(plan, str) else None
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("module_1_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
)
|
||||
memory = self._vlm_field(self._text_message(prompt), "memory")
|
||||
return memory.strip() if isinstance(memory, str) else ""
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prompt templates loaded as plain text.
|
||||
|
||||
One file per use site. Templates use ``str.format(**vars)`` substitution; we
|
||||
intentionally avoid jinja2 here so the templates remain inspectable in
|
||||
plain editors and roundtrip cleanly through ``ruff format``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load(name: str) -> str:
|
||||
"""Read prompt template ``name.txt`` from the ``prompts/`` directory."""
|
||||
path = _DIR / f"{name}.txt"
|
||||
return path.read_text(encoding="utf-8")
|
||||
@@ -0,0 +1,25 @@
|
||||
You are updating the robot's compressed semantic memory at the boundary of
|
||||
a completed subtask.
|
||||
|
||||
Reference (verbatim from MEM, Torne 2026):
|
||||
"Remove or compress information in the language memory whenever
|
||||
appropriate. Keep ONLY the minimal set of relevant information for future
|
||||
task execution. Specific object attributes (colors, precise quantities of
|
||||
each item) get discarded when their details won't affect subsequent
|
||||
actions. Functional outcomes (where items went, how many) are preserved."
|
||||
|
||||
Concrete example from MEM:
|
||||
Before: "I put a light green bowl, a dark blue bowl and a bright yellow
|
||||
bowl into the top right cabinet"
|
||||
After: "I placed three bowls in the top right cabinet"
|
||||
|
||||
Episode task: "{episode_task}"
|
||||
Previous memory: {prior_memory}
|
||||
Just-completed subtask: "{completed_subtask}"
|
||||
Remaining subtasks (for relevance judgement only): {remaining_subtasks}
|
||||
|
||||
Update the memory. Drop irrelevant detail. Compress completed steps.
|
||||
Keep WHAT happened, drop HOW. Shorter is better.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "memory": "<one or two short sentences>" }}
|
||||
@@ -0,0 +1,18 @@
|
||||
You are the high-level planner for a robot demonstrating: "{episode_task}".
|
||||
|
||||
Given the subtask decomposition below, write a concise hierarchical PLAN
|
||||
the robot should follow. Format the plan as a numbered list, one line per
|
||||
high-level step. The plan describes the full task; subtasks are the atomic
|
||||
skills used to execute it.
|
||||
|
||||
Subtasks for context:
|
||||
{subtasks_text}
|
||||
|
||||
Authoring rules:
|
||||
- 3 to {plan_max_steps} steps.
|
||||
- Each step describes one logical chunk of the task, not one motion.
|
||||
- Steps must be in execution order.
|
||||
- Plain prose, no JSON, no markdown headers.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "plan": "1. ...\n2. ...\n3. ..." }}
|
||||
@@ -0,0 +1,33 @@
|
||||
You are labeling a teleoperated robot demonstration.
|
||||
|
||||
The user originally asked: "{episode_task}"
|
||||
|
||||
You are shown the entire demonstration as a single video. Watch the
|
||||
whole clip, then segment it into a list of consecutive atomic subtasks
|
||||
the robot performs.
|
||||
|
||||
Authoring rules — based on Hi Robot (Shi 2025) atom granularity and
|
||||
Pi0.7 (Physical Intelligence 2025) "how, not what" detail:
|
||||
|
||||
- Each subtask is one atomic skill the low-level policy can execute,
|
||||
e.g. "pick up one piece of lettuce", "place the bowl into the box",
|
||||
"move the right arm to the left".
|
||||
- Capture HOW the subtask is performed, not only WHAT — e.g. prefer
|
||||
"grasp the handle of the sponge with the left hand" to "pick up the
|
||||
sponge".
|
||||
- Subtasks are non-overlapping and cover the full episode in order.
|
||||
Choose the cut points yourself based on what you see in the video
|
||||
(gripper open/close events, contact, regrasps, transitions).
|
||||
- Each subtask spans at least {min_subtask_seconds} seconds.
|
||||
- Do not exceed {max_steps} subtasks total.
|
||||
- Every subtask's [start_time, end_time] must lie within
|
||||
[0.0, {episode_duration}] seconds.
|
||||
|
||||
Output strictly valid JSON of shape:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{"text": "<how-not-what>", "start": <float>, "end": <float>}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,32 @@
|
||||
You are generating training data for a Hi Robot-style policy. We need
|
||||
{n} alternative phrasings of the same robot task so the policy sees
|
||||
diverse user prompts during training instead of the same canonical
|
||||
string repeated every frame.
|
||||
|
||||
Original task:
|
||||
"{base_task}"
|
||||
|
||||
Generate exactly {n} alternative phrasings of the same task. Vary:
|
||||
|
||||
- formality (casual / polite / curt)
|
||||
- verbosity (short imperative vs longer polite request)
|
||||
- word choice (synonyms, different verbs)
|
||||
- sentence structure (imperative / question / suggestion)
|
||||
|
||||
Hard rules:
|
||||
- Each phrasing MUST preserve the exact meaning of the original task.
|
||||
Do not change which object is involved, the destination, or the
|
||||
action. Do not add extra steps. Do not invent new objects.
|
||||
- Each phrasing must be a single short sentence, plain prose, no
|
||||
markdown, no quotes, no list numbers.
|
||||
- Phrasings must be distinct — no near-duplicates.
|
||||
- Output exactly {n} entries.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"rephrasings": [
|
||||
"<phrasing 1>",
|
||||
"<phrasing 2>",
|
||||
...
|
||||
]
|
||||
}}
|
||||
@@ -0,0 +1,17 @@
|
||||
The video above shows a robot manipulation episode in full. Look at
|
||||
the entire video and describe in ONE concise sentence what the robot
|
||||
is doing.
|
||||
|
||||
Rules:
|
||||
- One sentence, in natural English, like a user instruction.
|
||||
- Capture the goal of the demonstration, not low-level motions.
|
||||
Example: "place the yellow cube into the red bin" — not "move the
|
||||
end-effector down 5cm and close the gripper".
|
||||
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
|
||||
- Do not invent objects or actions that aren't visible.
|
||||
- Do not output anything other than the JSON object below.
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"task": "<single concise sentence describing what the robot does in this video>"
|
||||
}}
|
||||
@@ -0,0 +1,10 @@
|
||||
The user just asked the robot: "{episode_task}".
|
||||
|
||||
Generate a short verbal acknowledgement the robot would speak back before
|
||||
beginning the task. Style: confident, friendly, single short sentence.
|
||||
|
||||
Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
|
||||
"OK, starting with the sponge.", "Got it.".
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{ "text": "<the spoken acknowledgement>" }}
|
||||
@@ -0,0 +1,46 @@
|
||||
You are generating training data for a Hi Robot-style hierarchical
|
||||
robot policy. The robot in this demonstration has ALREADY executed
|
||||
every step shown in the video — we cannot retroactively change the
|
||||
action stream. To keep training data consistent with the video, the
|
||||
"interjection" must align with what the robot is *about to do next* in
|
||||
the demonstration, framed as a natural mid-task user request.
|
||||
|
||||
The episode's overall task: "{episode_task}".
|
||||
|
||||
The images above show roughly {window_seconds:.1f} seconds straddling a
|
||||
subtask boundary in the demonstration:
|
||||
|
||||
- Subtask the robot just finished: "{prev_subtask}"
|
||||
- Subtask the robot is about to start: "{next_subtask}"
|
||||
- Time into episode: {timestamp:.2f}s
|
||||
|
||||
Write ONE interjection the user would naturally say at this moment to
|
||||
prompt / confirm / encourage the robot to do "{next_subtask}". Phrase it
|
||||
like a real human mid-task remark — conversational, varied, sometimes
|
||||
just a nudge, sometimes a clarification, sometimes a small constraint
|
||||
that the upcoming motion happens to satisfy. Plus the robot's verbal
|
||||
acknowledgement.
|
||||
|
||||
Hard rules:
|
||||
|
||||
- The interjection MUST be consistent with the next subtask. The user
|
||||
cannot ask for something different from what the robot then does in
|
||||
the video. If you're tempted to say "actually skip X" or "do Y
|
||||
instead", DO NOT — those would contradict the demonstration.
|
||||
- The interjection must reference an object, location, or action that
|
||||
is plausible given the visible scene and the next subtask text.
|
||||
- One sentence each. Conversational, not robotic.
|
||||
|
||||
Style examples (vary the phrasing — don't reuse these verbatim):
|
||||
- "Now go ahead and {next_subtask}."
|
||||
- "Great, can you {next_subtask} next?"
|
||||
- "{next_subtask}, please."
|
||||
- "Before you continue, please {next_subtask}."
|
||||
- "Looking good — {next_subtask} now."
|
||||
- "Okay, {next_subtask}."
|
||||
|
||||
Output strictly valid JSON:
|
||||
{{
|
||||
"interjection": "<single sentence the user says, asking for the next subtask>",
|
||||
"speech": "<single sentence the robot speaks back, confirming and starting>"
|
||||
}}
|
||||
@@ -0,0 +1,32 @@
|
||||
You are generating a frame-grounded visual question/answer pair for
|
||||
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
|
||||
Policies — both train policies on grounded features such as bounding box
|
||||
pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
|
||||
"value": "<observed value>"}}
|
||||
|
||||
spatial => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
|
||||
"above|below|near>", "object": "<obj>"}}
|
||||
|
||||
Generate a question of type "{question_type}". Output strictly valid JSON:
|
||||
|
||||
{{
|
||||
"question": "<short, frame-grounded question>",
|
||||
"answer": <object whose shape matches the schema above>
|
||||
}}
|
||||
@@ -0,0 +1,268 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Datatrove-shaped reader.
|
||||
|
||||
The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
|
||||
episode containing:
|
||||
|
||||
- ``episode_index``: int
|
||||
- ``frame_timestamps``: tuple[float, ...]
|
||||
- ``frame_indices``: tuple[int, ...]
|
||||
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
|
||||
- ``data_path``: pathlib.Path of the source parquet shard
|
||||
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)
|
||||
|
||||
This shape lets each module operate per-episode without loading all parquet
|
||||
rows into memory at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_TASKS_PATH
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeRecord:
|
||||
"""Per-episode record yielded by the reader."""
|
||||
|
||||
episode_index: int
|
||||
episode_task: str
|
||||
frame_timestamps: tuple[float, ...]
|
||||
frame_indices: tuple[int, ...]
|
||||
data_path: Path
|
||||
row_offset: int # row offset within the parquet file where this episode starts
|
||||
row_count: int # number of rows for this episode
|
||||
|
||||
# Memoized parquet slice — populated on first ``frames_df()`` call so
|
||||
# repeat queries from different modules don't re-read the whole shard.
|
||||
_frames_df_cache: Any = field(default=None, init=False, repr=False, compare=False)
|
||||
|
||||
def frames_df(self): # type: ignore[no-untyped-def]
|
||||
"""Lazy-load the pandas slice for this episode (memoized)."""
|
||||
if self._frames_df_cache is None:
|
||||
import pandas as pd # noqa: PLC0415 - deferred for optional dataset extra
|
||||
|
||||
table = pq.read_table(self.data_path)
|
||||
df: pd.DataFrame = table.to_pandas()
|
||||
self._frames_df_cache = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(
|
||||
drop=True
|
||||
)
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by Module 1 (plan-update pass) and Module 2 (interjection
|
||||
anchoring), which both need the same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
Modules use this when emitting event-style rows so the row's
|
||||
timestamp matches a real parquet frame (event rows must land on an
|
||||
exact frame, see PR 1's "exact event matching" rule).
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
def _load_tasks_lookup(root: Path) -> dict[int, str]:
|
||||
tasks_path = root / DEFAULT_TASKS_PATH
|
||||
if not tasks_path.exists():
|
||||
return {}
|
||||
table = pq.read_table(tasks_path)
|
||||
cols = {name: table.column(name).to_pylist() for name in table.column_names}
|
||||
if "task_index" in cols and "task" in cols:
|
||||
return dict(zip(cols["task_index"], cols["task"], strict=True))
|
||||
raise ValueError(f"meta/tasks.parquet at {tasks_path} missing 'task_index' or 'task'")
|
||||
|
||||
|
||||
def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
|
||||
"""Yield :class:`EpisodeRecord` for every episode under ``root/data/``.
|
||||
|
||||
Episodes are yielded in ascending ``episode_index`` order. The reader does
|
||||
not assume a specific chunk/file layout: it scans every ``*.parquet``
|
||||
under ``data/`` and groups by ``episode_index``.
|
||||
"""
|
||||
tasks = _load_tasks_lookup(root)
|
||||
data_dir = root / "data"
|
||||
parquet_files = sorted(data_dir.rglob("*.parquet"))
|
||||
|
||||
only_set = set(only_episodes) if only_episodes is not None else None
|
||||
|
||||
for path in parquet_files:
|
||||
yield from _iter_one_path(path, tasks, only_set)
|
||||
|
||||
|
||||
def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
|
||||
table = pq.read_table(path)
|
||||
names = table.column_names
|
||||
if "episode_index" not in names:
|
||||
return
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
timestamp_col = (
|
||||
table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
|
||||
)
|
||||
frame_col = (
|
||||
table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
|
||||
)
|
||||
task_col = table.column("task_index").to_pylist() if "task_index" in names else None
|
||||
|
||||
def _build(
|
||||
ep: int,
|
||||
start: int,
|
||||
end: int,
|
||||
task_idx: int | None,
|
||||
ts_buf: list[float],
|
||||
fi_buf: list[int],
|
||||
) -> EpisodeRecord | None:
|
||||
if only_set is not None and ep not in only_set:
|
||||
return None
|
||||
task = tasks.get(task_idx, "") if task_idx is not None else ""
|
||||
return EpisodeRecord(
|
||||
episode_index=ep,
|
||||
episode_task=task,
|
||||
frame_timestamps=tuple(ts_buf),
|
||||
frame_indices=tuple(fi_buf),
|
||||
data_path=path,
|
||||
row_offset=start,
|
||||
row_count=end - start,
|
||||
)
|
||||
|
||||
cur_ep: int | None = None
|
||||
start_offset = 0
|
||||
ts_buf: list[float] = []
|
||||
fi_buf: list[int] = []
|
||||
cur_task_idx: int | None = None
|
||||
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
cur_ep = ep
|
||||
start_offset = i
|
||||
ts_buf = [timestamp_col[i]]
|
||||
fi_buf = [frame_col[i]]
|
||||
cur_task_idx = task_col[i] if task_col is not None else None
|
||||
else:
|
||||
ts_buf.append(timestamp_col[i])
|
||||
fi_buf.append(frame_col[i])
|
||||
|
||||
if cur_ep is not None:
|
||||
rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
|
||||
if rec is not None:
|
||||
yield rec
|
||||
|
||||
|
||||
def gather_data_paths(root: Path) -> list[Path]:
|
||||
"""Return every ``data/chunk-*/file-*.parquet`` path under ``root``."""
|
||||
return sorted((root / "data").rglob("*.parquet"))
|
||||
|
||||
|
||||
def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
|
||||
"""Return ``{episode_index: (row_offset, row_count)}`` for one parquet."""
|
||||
table = pq.read_table(path, columns=["episode_index"])
|
||||
episode_col = table.column("episode_index").to_pylist()
|
||||
out: dict[int, tuple[int, int]] = {}
|
||||
cur_ep: int | None = None
|
||||
start = 0
|
||||
for i, ep in enumerate(episode_col):
|
||||
if cur_ep is None:
|
||||
cur_ep = ep
|
||||
start = i
|
||||
continue
|
||||
if ep != cur_ep:
|
||||
out[cur_ep] = (start, i - start)
|
||||
cur_ep = ep
|
||||
start = i
|
||||
if cur_ep is not None:
|
||||
out[cur_ep] = (start, len(episode_col) - start)
|
||||
return out
|
||||
|
||||
|
||||
def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
|
||||
"""Return ``k`` evenly spaced row indices into the episode (relative)."""
|
||||
n = record.row_count
|
||||
if k <= 0 or n == 0:
|
||||
return []
|
||||
if k >= n:
|
||||
return list(range(n))
|
||||
step = (n - 1) / (k - 1) if k > 1 else 0.0
|
||||
return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]
|
||||
|
||||
|
||||
def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
|
||||
"""Find the parquet file containing ``episode_index`` and its slice bounds."""
|
||||
for path in gather_data_paths(root):
|
||||
offsets = episode_offsets_per_path(path)
|
||||
if episode_index in offsets:
|
||||
start, count = offsets[episode_index]
|
||||
return path, start, count
|
||||
return None
|
||||
|
||||
|
||||
def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
|
||||
"""Return the parquet path and per-frame timestamps for ``episode_index``."""
|
||||
found = lookup_data_path(root, episode_index)
|
||||
if found is None:
|
||||
raise ValueError(f"Episode {episode_index} not found under {root}/data/")
|
||||
path, start, count = found
|
||||
table = pq.read_table(path, columns=["timestamp"])
|
||||
timestamps = table.column("timestamp").to_pylist()[start : start + count]
|
||||
return path, [float(t) for t in timestamps]
|
||||
@@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Per-episode staging.
|
||||
|
||||
Each module writes its raw output as a JSONL file under
|
||||
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
|
||||
staging tree and partitions rows into the two language columns.
|
||||
|
||||
JSONL is preferred over parquet here because the staging artifact is meant to
|
||||
be human-inspectable, easy to diff between prompt iterations, and trivially
|
||||
appended to. The final dataset format is parquet; staging is just an
|
||||
intermediate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ModuleName = str
|
||||
|
||||
_MODULES: tuple[ModuleName, ...] = (
|
||||
"module_1",
|
||||
"module_2",
|
||||
"module_3",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeStaging:
|
||||
"""Filesystem layout for a single episode's staged module outputs."""
|
||||
|
||||
root: Path
|
||||
episode_index: int
|
||||
|
||||
@property
|
||||
def episode_dir(self) -> Path:
|
||||
return self.root / f"episode_{self.episode_index:06d}"
|
||||
|
||||
def path_for(self, module: ModuleName) -> Path:
|
||||
if module not in _MODULES:
|
||||
raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
|
||||
return self.episode_dir / f"{module}.jsonl"
|
||||
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
path = self.path_for(module)
|
||||
if not path.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
|
||||
return {m: self.read(m) for m in _MODULES}
|
||||
|
||||
def has(self, module: ModuleName) -> bool:
|
||||
return self.path_for(module).exists()
|
||||
|
||||
|
||||
def iter_staged_episodes(root: Path) -> Iterator[int]:
|
||||
"""Yield episode indices for which any staging artifact exists."""
|
||||
if not root.exists():
|
||||
return
|
||||
for child in sorted(root.iterdir()):
|
||||
if child.is_dir() and child.name.startswith("episode_"):
|
||||
try:
|
||||
yield int(child.name.removeprefix("episode_"))
|
||||
except ValueError:
|
||||
continue
|
||||
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pre-write validation against staged outputs.
|
||||
|
||||
Runs after Modules 1–3 have all written their per-episode artifacts but
|
||||
*before* the writer rewrites parquet shards. The validator never touches
|
||||
parquet; it only inspects the staging tree and the source frame timestamps
|
||||
exposed by :class:`EpisodeRecord`.
|
||||
|
||||
Checks (per the plan's "Intermediate staging and validation" section):
|
||||
|
||||
- exact timestamp alignment against source frame timestamps
|
||||
- no orphan speech / interjection pairs
|
||||
- plan / memory emission consistency (events have a paired persistent row)
|
||||
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
|
||||
attribute / spatial)
|
||||
- every row maps to its correct column under :func:`column_for_style`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationReport:
|
||||
"""Outcome of one validation pass across all episodes."""
|
||||
|
||||
errors: list[str] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
episodes_checked: int = 0
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return not self.errors
|
||||
|
||||
def add_error(self, message: str) -> None:
|
||||
self.errors.append(message)
|
||||
|
||||
def add_warning(self, message: str) -> None:
|
||||
self.warnings.append(message)
|
||||
|
||||
def summary(self) -> str:
|
||||
return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"
|
||||
|
||||
|
||||
VQA_ANSWER_SHAPES: dict[str, set[str]] = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
|
||||
def classify_vqa_answer(payload: Any) -> str | None:
|
||||
"""Best-effort classification of a VQA answer payload to a question type."""
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
keys = set(payload.keys())
|
||||
for kind, required in VQA_ANSWER_SHAPES.items():
|
||||
if required.issubset(keys):
|
||||
return kind
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
) -> ValidationReport:
|
||||
report = ValidationReport()
|
||||
for record in records:
|
||||
self._validate_episode(record, staging_dir, report)
|
||||
report.episodes_checked += 1
|
||||
return report
|
||||
|
||||
def _validate_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging_dir: Path,
|
||||
report: ValidationReport,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged = staging.read_all()
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
for module_name, rows in staged.items():
|
||||
for row in rows:
|
||||
row = {**row, "_module": module_name}
|
||||
all_rows.append(row)
|
||||
|
||||
frame_ts = set(record.frame_timestamps)
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(
|
||||
row, report, record.episode_index, self.dataset_camera_keys
|
||||
)
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
|
||||
for row in events:
|
||||
self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)
|
||||
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: {exc}"
|
||||
)
|
||||
return
|
||||
if (
|
||||
is_view_dependent_style(style)
|
||||
and dataset_camera_keys
|
||||
and camera not in dataset_camera_keys
|
||||
):
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
style = row.get("style")
|
||||
module = row.get("_module")
|
||||
try:
|
||||
target_col = column_for_style(style)
|
||||
except ValueError:
|
||||
report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
|
||||
return
|
||||
if module == "module_1" and target_col != LANGUAGE_PERSISTENT:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)"
|
||||
)
|
||||
if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
|
||||
)
|
||||
|
||||
def _check_event_timestamp_alignment(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
frame_ts: set[float],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
|
||||
return
|
||||
if self.timestamp_atol == 0.0:
|
||||
if float(ts) not in frame_ts:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
|
||||
)
|
||||
else:
|
||||
if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
|
||||
)
|
||||
|
||||
def _check_speech_interjection_pairs(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
speech_ts: dict[float, int] = {}
|
||||
interjection_ts: dict[float, int] = {}
|
||||
for row in events:
|
||||
ts = row.get("timestamp")
|
||||
if ts is None:
|
||||
continue
|
||||
ts_f = float(ts)
|
||||
if row.get("style") is None and row.get("role") == "assistant":
|
||||
speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
|
||||
if row.get("style") == "interjection":
|
||||
interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1
|
||||
|
||||
for ts in interjection_ts:
|
||||
if ts not in speech_ts:
|
||||
report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")
|
||||
|
||||
def _check_plan_memory_consistency(
|
||||
self,
|
||||
persistent: Sequence[dict[str, Any]],
|
||||
events: Sequence[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
|
||||
memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
|
||||
subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
|
||||
interjection_ts = sorted(
|
||||
{
|
||||
float(r["timestamp"])
|
||||
for r in events
|
||||
if r.get("style") == "interjection" and r.get("timestamp") is not None
|
||||
}
|
||||
)
|
||||
|
||||
if persistent and not plan_ts:
|
||||
report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
|
||||
# every interjection should have a same-timestamp plan refresh
|
||||
for ts in interjection_ts:
|
||||
if ts not in set(plan_ts):
|
||||
report.add_error(
|
||||
f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
|
||||
)
|
||||
# memory should be emitted at subtask boundaries (subset relation)
|
||||
if memory_ts and subtask_ts:
|
||||
mem_set = set(memory_ts)
|
||||
sub_set = set(subtask_ts)
|
||||
stray = sorted(mem_set - sub_set)
|
||||
if stray:
|
||||
report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")
|
||||
|
||||
def _check_vqa_json(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
for row in events:
|
||||
if row.get("style") != "vqa" or row.get("role") != "assistant":
|
||||
continue
|
||||
content = row.get("content")
|
||||
if content is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
|
||||
)
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except (TypeError, ValueError) as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
|
||||
)
|
||||
continue
|
||||
shape = classify_vqa_answer(payload)
|
||||
if shape is None:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
|
||||
)
|
||||
@@ -0,0 +1,723 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared Qwen-VL client.
|
||||
|
||||
The pipeline uses a single shared VLM across modules. vLLM is preferred when
|
||||
available (high throughput, JSON-guided decoding); transformers is the
|
||||
fallback. A ``stub`` backend is used for unit tests so fixtures never call
|
||||
into a real model.
|
||||
|
||||
The client speaks one method, :meth:`VlmClient.generate_json`, which:
|
||||
|
||||
- accepts a list of OpenAI/HF-style multimodal messages,
|
||||
- requests JSON output (``json_mode=True`` enables guided decoding when the
|
||||
backend supports it),
|
||||
- batches requests transparently,
|
||||
- and reprompts once on a JSON parse failure with an inline correction
|
||||
message before raising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .config import VlmConfig
|
||||
|
||||
|
||||
class VlmClient(Protocol):
|
||||
"""Protocol every backend must implement."""
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
"""Generate one JSON-decoded response per messages list."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubVlmClient:
|
||||
"""Deterministic stub used in unit tests.
|
||||
|
||||
A test passes a callable that maps the *last user message text* (or, if
|
||||
that is empty, the full message list) to a JSON-serializable response.
|
||||
"""
|
||||
|
||||
responder: Callable[[Sequence[dict[str, Any]]], Any]
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
return [self.responder(list(messages)) for messages in messages_batch]
|
||||
|
||||
|
||||
def _strip_to_json(text: str) -> Any:
|
||||
text = text.strip()
|
||||
# Strip <think>...</think> blocks (Qwen3 Thinking style)
|
||||
while "<think>" in text and "</think>" in text:
|
||||
start = text.find("<think>")
|
||||
end = text.find("</think>", start) + len("</think>")
|
||||
text = (text[:start] + text[end:]).strip()
|
||||
# Strip ```json ... ``` fences from chat-tuned backbones
|
||||
if text.startswith("```"):
|
||||
first = text.find("\n")
|
||||
last = text.rfind("```")
|
||||
if first != -1 and last != -1 and last > first:
|
||||
text = text[first + 1 : last].strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
# Fall back to extracting the first balanced {...} block.
|
||||
obj_text = _extract_first_json_object(text)
|
||||
if obj_text is None:
|
||||
raise json.JSONDecodeError("No JSON object found", text, 0)
|
||||
return json.loads(obj_text)
|
||||
|
||||
|
||||
def _extract_first_json_object(text: str) -> str | None:
|
||||
"""Return the first balanced ``{...}`` substring, ignoring braces in
|
||||
string literals. Returns ``None`` if no balanced block is found."""
|
||||
start = text.find("{")
|
||||
if start < 0:
|
||||
return None
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
ch = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
continue
|
||||
# Note: ``escape`` is always False here — the ``if escape`` branch
|
||||
# above already handled and reset it.
|
||||
if ch == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
depth += 1
|
||||
elif ch == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[start : i + 1]
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _GenericTextClient:
|
||||
"""Wraps any text-generation callable in JSON-mode + one-retry semantics."""
|
||||
|
||||
generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
|
||||
config: VlmConfig
|
||||
|
||||
def generate_json(
|
||||
self,
|
||||
messages_batch: Sequence[Sequence[dict[str, Any]]],
|
||||
*,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
) -> list[Any]:
|
||||
max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
|
||||
temp = temperature if temperature is not None else self.config.temperature
|
||||
raw = self.generate_text(messages_batch, max_tok, temp)
|
||||
out: list[Any] = []
|
||||
for messages, text in zip(messages_batch, raw, strict=True):
|
||||
try:
|
||||
out.append(_strip_to_json(text))
|
||||
continue
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
retry = list(messages) + [
|
||||
{"role": "assistant", "content": text},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Your previous reply was not valid JSON. "
|
||||
"Reply with strictly valid JSON, no prose, no fences."
|
||||
),
|
||||
},
|
||||
]
|
||||
retry_text = self.generate_text([retry], max_tok, temp)[0]
|
||||
try:
|
||||
out.append(_strip_to_json(retry_text))
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
# After retry: log preview and return None instead of crashing
|
||||
# the whole pipeline. Modules treat None as "skip".
|
||||
preview = retry_text.strip().replace("\n", " ")[:200]
|
||||
print(
|
||||
f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
out.append(None)
|
||||
return out
|
||||
|
||||
|
||||
def make_vlm_client(config: VlmConfig) -> VlmClient:
|
||||
"""Build the shared VLM client per the configured backend.
|
||||
|
||||
For ``stub``, callers should construct :class:`StubVlmClient` directly with
|
||||
a responder callable. ``stub`` here is rejected to make accidental misuse
|
||||
obvious.
|
||||
"""
|
||||
if config.backend == "stub":
|
||||
raise ValueError(
|
||||
"Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
|
||||
)
|
||||
if config.backend == "vllm":
|
||||
return _make_vllm_client(config)
|
||||
if config.backend == "transformers":
|
||||
return _make_transformers_client(config)
|
||||
if config.backend == "openai":
|
||||
return _make_openai_client(config)
|
||||
raise ValueError(f"Unknown VLM backend: {config.backend!r}")
|
||||
|
||||
|
||||
def _make_vllm_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
from vllm import LLM, SamplingParams # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
|
||||
) from exc
|
||||
# Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
|
||||
# as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
|
||||
# embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
|
||||
# convolution kernels — slower but functional.
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
if _os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
_torch.backends.cudnn.enabled = False
|
||||
llm_kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"tensor_parallel_size": config.tensor_parallel_size,
|
||||
"gpu_memory_utilization": config.gpu_memory_utilization,
|
||||
"trust_remote_code": config.trust_remote_code,
|
||||
}
|
||||
if config.max_model_len is not None:
|
||||
llm_kwargs["max_model_len"] = config.max_model_len
|
||||
llm = LLM(**llm_kwargs)
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
# ``guided_decoding`` would speed up parsing but its API differs across
|
||||
# vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
|
||||
# wrapper already has a one-retry JSON-recovery path, so we skip it.
|
||||
params = SamplingParams(max_tokens=max_tok, temperature=temp)
|
||||
# ``llm.chat`` handles chat-template application + multimodal input
|
||||
# extraction (image/video blocks) internally, which ``llm.generate``
|
||||
# does not.
|
||||
outputs = llm.chat([list(m) for m in batch], params)
|
||||
return [o.outputs[0].text for o in outputs]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_transformers_client(config: VlmConfig) -> VlmClient:
|
||||
try:
|
||||
import torch # type: ignore[import-not-found]
|
||||
import transformers # type: ignore[import-not-found]
|
||||
from transformers import AutoProcessor # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError("transformers + torch are required for backend='transformers'.") from exc
|
||||
auto_cls = getattr(transformers, "AutoModelForImageTextToText", None) or getattr(
|
||||
transformers, "AutoModelForVision2Seq", None
|
||||
)
|
||||
if auto_cls is None:
|
||||
raise ImportError(
|
||||
"Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
|
||||
"transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
|
||||
"for VL models."
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(config.model_id, trust_remote_code=config.trust_remote_code)
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
use_accelerate = _os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
|
||||
# ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
|
||||
# post-load dispatch path (the alloc fails in accelerate's hook setup
|
||||
# even with TBs of host RAM). Default to manual: load on CPU with
|
||||
# ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
|
||||
# ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
|
||||
if use_accelerate:
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
model = auto_cls.from_pretrained(
|
||||
config.model_id,
|
||||
torch_dtype=_torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
)
|
||||
model = model.to("cuda")
|
||||
model.eval()
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
outs: list[str] = []
|
||||
for messages in batch:
|
||||
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
||||
inputs = processor(text=[text], return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
gen = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tok,
|
||||
temperature=temp,
|
||||
do_sample=temp > 0.0,
|
||||
)
|
||||
decoded = processor.batch_decode(
|
||||
gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True
|
||||
)[0]
|
||||
outs.append(decoded)
|
||||
return outs
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _make_openai_client(config: VlmConfig) -> VlmClient:
|
||||
"""Backend that talks to any OpenAI-compatible server.
|
||||
|
||||
Compatible with ``vllm serve``, ``transformers serve``,
|
||||
``ktransformers serve``, and hosted endpoints. By default the server
|
||||
is expected to be already running. Set ``auto_serve=True`` to have
|
||||
this client spawn one (default: ``transformers serve``), wait until
|
||||
it's ready, and tear it down on process exit.
|
||||
|
||||
Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
|
||||
auto-converted to ``image_url`` data-URLs. Video blocks
|
||||
``{"type":"video", "video":[<PIL>...]}`` are forwarded as
|
||||
multi-frame ``video_url`` items where supported.
|
||||
"""
|
||||
try:
|
||||
from openai import OpenAI # type: ignore[import-not-found]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"openai package is required for backend='openai'. Install with `pip install openai`."
|
||||
) from exc
|
||||
|
||||
api_base = config.api_base
|
||||
api_key = config.api_key
|
||||
auto_serve = config.auto_serve
|
||||
api_bases: list[str] = [api_base]
|
||||
|
||||
print(
|
||||
f"[lerobot-annotate] backend=openai model={config.model_id} "
|
||||
f"api_base={api_base} auto_serve={auto_serve}",
|
||||
flush=True,
|
||||
)
|
||||
if auto_serve:
|
||||
if config.parallel_servers > 1:
|
||||
print(
|
||||
f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
|
||||
flush=True,
|
||||
)
|
||||
api_bases = _spawn_parallel_inference_servers(config)
|
||||
elif _server_is_up(api_base):
|
||||
print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
|
||||
else:
|
||||
print("[lerobot-annotate] no server reachable; spawning one", flush=True)
|
||||
api_base = _spawn_inference_server(config)
|
||||
api_bases = [api_base]
|
||||
print(f"[lerobot-annotate] server ready at {api_base}", flush=True)
|
||||
|
||||
clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
|
||||
# round-robin counter for parallel mode
|
||||
rr_counter = {"i": 0}
|
||||
|
||||
# ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
|
||||
# rejects it with HTTP 422. Send it only when explicitly opted in via
|
||||
# an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
|
||||
send_mm_kwargs = os.environ.get("LEROBOT_OPENAI_SEND_MM_KWARGS", "").lower() in {"1", "true", "yes"}
|
||||
|
||||
rr_lock = threading.Lock()
|
||||
|
||||
def _one_call(messages: Sequence[dict[str, Any]], max_tok: int, temp: float) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
|
||||
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [pool.submit(_one_call, messages, max_tok, temp) for messages in batch]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
import atexit # noqa: PLC0415
|
||||
import os as _os # noqa: PLC0415
|
||||
import shlex # noqa: PLC0415
|
||||
import signal # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
import sys # noqa: PLC0415
|
||||
import threading # noqa: PLC0415
|
||||
import time # noqa: PLC0415
|
||||
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = _os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd.replace("{port}", str(port)) if "{port}" in base_cmd else f"{base_cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s")
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
# ``api_base`` is the user-configured local-server URL we just spawned
|
||||
# or the user passed in via ``--vlm.api_base``; the bandit B310 warning
|
||||
# is for arbitrary user-controlled URLs with file:/ schemes which
|
||||
# cannot reach this code path.
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp: # noqa: S310 # nosec B310
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
import atexit # noqa: PLC0415
|
||||
import shlex # noqa: PLC0415
|
||||
import signal # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
import sys # noqa: PLC0415
|
||||
import threading # noqa: PLC0415
|
||||
import time # noqa: PLC0415
|
||||
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(f"[server] did not become ready within {config.serve_ready_timeout_s}s")
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append({"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}})
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
import base64 # noqa: PLC0415
|
||||
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
import base64 # noqa: PLC0415
|
||||
import io # noqa: PLC0415
|
||||
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
|
||||
"""Pass-through hook used by the vllm backend.
|
||||
|
||||
vllm exposes its own multimodal entry points that vary by version; for the
|
||||
base flow we simply forward the raw message list and let the caller's
|
||||
custom backend handle templating. Real deployments override this.
|
||||
"""
|
||||
return list(messages)
|
||||
@@ -0,0 +1,356 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct (PR 1) for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants moved to lerobot.datasets.language in PR 1 — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Surface a friendly error from the writer rather than letting
|
||||
# the raw KeyError bubble out of the dict access below — modules
|
||||
# are expected to always emit ``role``, but the validator
|
||||
# currently doesn't check this so a future bug would otherwise
|
||||
# be hard to triage.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"timestamp": float(row["timestamp"]),
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
|
||||
# canonical type from `lerobot.datasets.language` directly: that type
|
||||
# uses `pa.json_()` for the `tool_calls` element type, which
|
||||
# `pa.array(..., type=...)` cannot materialize from Python lists on
|
||||
# current pyarrow versions. The inferred schema round-trips through
|
||||
# parquet and `LeRobotDataset` correctly — see PR 1's
|
||||
# `tests/datasets/test_language.py` which exercises the same flow.
|
||||
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_rows_for_writer(
|
||||
rows: Iterable[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Helper used by tests/validators to partition a flat row list into
|
||||
(persistent_rows, event_rows) using ``column_for_style``.
|
||||
"""
|
||||
persistent: list[dict[str, Any]] = []
|
||||
events: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
return persistent, events
|
||||
@@ -41,8 +41,12 @@ def cfg_to_group(
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
trainable_tag = f"reward_model:{cfg.reward_model.type}"
|
||||
else:
|
||||
trainable_tag = f"policy:{cfg.policy.type}"
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
trainable_tag,
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
|
||||
@@ -21,8 +21,10 @@ are intentionally NOT re-exported here to avoid circular dependencies
|
||||
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
"""
|
||||
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -39,9 +41,13 @@ __all__ = [
|
||||
"PolicyFeature",
|
||||
"RTCAttentionSchedule",
|
||||
# Config classes
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str = ""
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str = ""
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
|
||||
def stamp_repo_id(self) -> None:
|
||||
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
|
||||
|
||||
Must be called explicitly at dataset *creation* time — not on resume,
|
||||
where the existing ``repo_id`` (already stamped) must be preserved.
|
||||
"""
|
||||
if self.repo_id:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.repo_id = f"{self.repo_id}_{timestamp}"
|
||||
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
# ``stream`` is typed Optional only so the dataclass can keep its
|
||||
# field ordering, but recipes must always tag every turn with a
|
||||
# stream — the renderer's ``_validate_rendered`` would reject
|
||||
# ``None`` later on. Fail at construction so the bad recipe is
|
||||
# caught at YAML load time rather than at the first sample.
|
||||
if self.stream is None:
|
||||
raise ValueError(
|
||||
f"MessageTurn(role={self.role!r}) is missing a stream — "
|
||||
f"every turn must declare one of {sorted(_VALID_STREAMS)}."
|
||||
)
|
||||
if self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
@@ -0,0 +1,163 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
T = TypeVar("T", bound="RewardModelConfig")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""Base configuration for reward models.
|
||||
|
||||
Args:
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
"""
|
||||
|
||||
# Reuses PolicyFeature
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
# Hub metadata
|
||||
license: str | None = None
|
||||
tags: list[str] | None = None
|
||||
private: bool | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
pass
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**reward_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
# HACK: Parse the original config to get the config subclass, so that we can
|
||||
# apply cli overrides.
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
config.pop("type", None)
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(config, f)
|
||||
config_file = f.name
|
||||
|
||||
cli_overrides = reward_kwargs.pop("cli_overrides", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
import builtins
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -26,18 +28,57 @@ from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||
legacy_fields = (
|
||||
"use_rabc",
|
||||
"rabc_progress_path",
|
||||
"rabc_kappa",
|
||||
"rabc_epsilon",
|
||||
"rabc_head_mode",
|
||||
)
|
||||
if not any(key in config for key in legacy_fields):
|
||||
return None
|
||||
|
||||
migrated_config = dict(config)
|
||||
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||
|
||||
# New configs may already define sample_weighting explicitly. In that case,
|
||||
# legacy fields are ignored after being stripped from the payload.
|
||||
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||
if rabc_progress_path is not None:
|
||||
sample_weighting["progress_path"] = rabc_progress_path
|
||||
if rabc_kappa is not None:
|
||||
sample_weighting["kappa"] = rabc_kappa
|
||||
if rabc_epsilon is not None:
|
||||
sample_weighting["epsilon"] = rabc_epsilon
|
||||
if rabc_head_mode is not None:
|
||||
sample_weighting["head_mode"] = rabc_head_mode
|
||||
migrated_config["sample_weighting"] = sample_weighting
|
||||
|
||||
return migrated_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
reward_model: RewardModelConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
@@ -72,27 +113,41 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def is_reward_model_training(self) -> bool:
|
||||
"""True when the config targets a reward model rather than a policy."""
|
||||
return self.reward_model is not None
|
||||
|
||||
@property
|
||||
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
|
||||
"""Return whichever config (policy or reward_model) is active."""
|
||||
if self.is_reward_model_training:
|
||||
return self.reward_model # type: ignore[return-value]
|
||||
return self.policy # type: ignore[return-value]
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
reward_model_path = parser.get_path_arg("reward_model")
|
||||
|
||||
if reward_model_path:
|
||||
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||
self.reward_model = RewardModelConfig.from_pretrained(
|
||||
reward_model_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
@@ -108,18 +163,22 @@ class TrainPipelineConfig(HubMixin):
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None:
|
||||
if self.policy is None and self.reward_model is None:
|
||||
raise ValueError(
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
"Neither policy nor reward_model is configured. "
|
||||
"Please specify one with `--policy.path` or `--reward_model.path`."
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{self.policy.type}"
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
self.job_name = f"{self.env.type}_{active_cfg.type}"
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
raise FileExistsError(
|
||||
@@ -137,26 +196,16 @@ class TrainPipelineConfig(HubMixin):
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
self.optimizer = active_cfg.get_optimizer_preset()
|
||||
self.scheduler = active_cfg.get_scheduler_preset()
|
||||
|
||||
if self.policy.push_to_hub and not self.policy.repo_id:
|
||||
raise ValueError(
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
"""Keys for draccus pretrained-path loading."""
|
||||
return ["policy", "reward_model"]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
@@ -207,6 +256,15 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
if config_file is not None:
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
if migrated_config is not None:
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(migrated_config, f)
|
||||
config_file = f.name
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@@ -37,6 +37,14 @@ from .dataset_tools import (
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
@@ -53,10 +61,15 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"add_features",
|
||||
@@ -66,6 +79,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
|
||||
@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||
"""
|
||||
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
df["index"] = df["index"] + dst_meta.info.total_frames
|
||||
|
||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||
@@ -225,9 +225,9 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
|
||||
return df
|
||||
|
||||
@@ -237,8 +237,8 @@ def aggregate_datasets(
|
||||
aggr_repo_id: str,
|
||||
roots: list[Path] | None = None,
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -313,8 +313,8 @@ def aggregate_datasets(
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
dst_meta.info.total_episodes += src_meta.total_episodes
|
||||
dst_meta.info.total_frames += src_meta.total_frames
|
||||
|
||||
finalize_aggregation(dst_meta, all_metadata)
|
||||
logging.info("Aggregation complete.")
|
||||
@@ -640,14 +640,10 @@ def finalize_aggregation(aggr_meta, all_metadata):
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
logging.info("write info")
|
||||
aggr_meta.info.update(
|
||||
{
|
||||
"total_tasks": len(aggr_meta.tasks),
|
||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
||||
}
|
||||
)
|
||||
aggr_meta.info.total_tasks = len(aggr_meta.tasks)
|
||||
aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
|
||||
aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
|
||||
aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
logging.info("write stats")
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -34,16 +34,13 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
check_version_compatibility,
|
||||
get_safe_version,
|
||||
has_legacy_hub_download_metadata,
|
||||
@@ -177,7 +174,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -228,7 +224,7 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
return packaging.version.parse(self.info.codebase_version)
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
"""Return the relative parquet file path for the given episode index.
|
||||
@@ -283,27 +279,27 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
return self.info.data_path
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
return self.info.video_path
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info["robot_type"]
|
||||
return self.info.robot_type
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
return self.info.fps
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info["features"]
|
||||
return self.info.features
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
@@ -320,6 +316,39 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def has_language_columns(self) -> bool:
|
||||
"""Return ``True`` if the dataset declares any language column.
|
||||
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
|
||||
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
|
||||
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -333,32 +362,32 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
return self.info.total_episodes
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
return self.info.total_frames
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
return self.info.total_tasks
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
return self.info.chunks_size
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info["data_files_size_in_mb"]
|
||||
return self.info.data_files_size_in_mb
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
return self.info.video_files_size_in_mb
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
@@ -502,10 +531,10 @@ class LeRobotDatasetMetadata:
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info.total_episodes += 1
|
||||
self.info.total_frames += episode_length
|
||||
self.info.total_tasks = len(self.tasks)
|
||||
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
@@ -524,7 +553,7 @@ class LeRobotDatasetMetadata:
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -546,17 +575,17 @@ class LeRobotDatasetMetadata:
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info["chunks_size"] = chunks_size
|
||||
self.info.chunks_size = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
self.info.data_files_size_in_mb = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
self.info.video_files_size_in_mb = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
@@ -635,7 +664,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
@@ -653,7 +681,7 @@ class LeRobotDatasetMetadata:
|
||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
write_info(obj.info, obj.root)
|
||||
obj.revision = None
|
||||
obj._pq_writer = None
|
||||
obj.latest_episode = None
|
||||
|
||||
@@ -295,9 +295,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -897,14 +897,10 @@ def _copy_and_reindex_episodes_metadata(
|
||||
|
||||
dst_meta.finalize()
|
||||
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": len(episode_mapping),
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
|
||||
"splits": {"train": f"0:{len(episode_mapping)}"},
|
||||
}
|
||||
)
|
||||
dst_meta.info.total_episodes = len(episode_mapping)
|
||||
dst_meta.info.total_frames = total_frames
|
||||
dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
|
||||
dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
if not all_stats:
|
||||
@@ -1069,21 +1065,20 @@ def _copy_episodes_metadata_and_stats(
|
||||
if episodes_dir.exists():
|
||||
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
||||
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": src_dataset.meta.total_episodes,
|
||||
"total_frames": src_dataset.meta.total_frames,
|
||||
"total_tasks": src_dataset.meta.total_tasks,
|
||||
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
|
||||
}
|
||||
dst_meta.info.total_episodes = src_dataset.meta.total_episodes
|
||||
dst_meta.info.total_frames = src_dataset.meta.total_frames
|
||||
dst_meta.info.total_tasks = src_dataset.meta.total_tasks
|
||||
# Preserve original splits if available, otherwise create default
|
||||
dst_meta.info.splits = (
|
||||
src_dataset.meta.info.splits
|
||||
if src_dataset.meta.info.splits
|
||||
else {"train": f"0:{src_dataset.meta.total_episodes}"}
|
||||
)
|
||||
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
|
||||
"info", {}
|
||||
)
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
@@ -1525,7 +1520,7 @@ def modify_tasks(
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
dataset.meta.info.total_tasks = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
@@ -1858,10 +1853,10 @@ def convert_image_to_video_dataset(
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
new_meta.info.total_episodes = len(episode_indices)
|
||||
new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
@@ -1870,7 +1865,7 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from pprint import pformat
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
||||
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
|
||||
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
|
||||
``{observation,action,reward}_delta_indices`` properties used below.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
|
||||
if not cfg.dataset.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
|
||||
@@ -22,12 +22,19 @@ from PIL import Image as PILImage
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
DatasetInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -45,7 +52,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -78,8 +91,8 @@ def create_empty_dataset_info(
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
) -> DatasetInfo:
|
||||
"""Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
@@ -87,25 +100,24 @@ def create_empty_dataset_info(
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
|
||||
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
|
||||
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
DatasetInfo: A typed dataset information object with initial metadata.
|
||||
"""
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
return DatasetInfo(
|
||||
codebase_version=codebase_version,
|
||||
fps=fps,
|
||||
features=features,
|
||||
robot_type=robot_type,
|
||||
chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path=DEFAULT_DATA_PATH,
|
||||
video_path=DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
)
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
@@ -242,6 +254,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return ""
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
@@ -31,14 +31,15 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
DatasetInfo,
|
||||
serialize_dict,
|
||||
)
|
||||
|
||||
@@ -115,25 +116,21 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
return dataset
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path) -> None:
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
def write_info(info: DatasetInfo, local_dir: Path) -> None:
|
||||
write_json(info.to_dict(), local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
def load_info(local_dir: Path) -> DatasetInfo:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: The dataset information dictionary.
|
||||
DatasetInfo: The typed dataset information object.
|
||||
"""
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
raw = load_json(local_dir / INFO_PATH)
|
||||
return DatasetInfo.from_dict(raw)
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||
@@ -189,14 +186,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -268,11 +257,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -307,8 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float64(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float64"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(f"Rows of style {style!r} must have camera=None; got camera={camera!r}.")
|
||||
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
@@ -0,0 +1,543 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float64) and ``t``
|
||||
is also a float64 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz)."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
is within ``EMITTED_AT_TOLERANCE_S`` of ``t`` (see that constant for why
|
||||
we use a tolerance instead of bit-equality). For event styles, the
|
||||
``events`` list is assumed to come from the dataset row at frame ``t``
|
||||
(event rows carry no timestamp of their own), so all matching event rows
|
||||
are considered emitted at ``t``. ``camera`` filters by the row's
|
||||
``camera`` field — required to disambiguate when multiple view-dependent
|
||||
rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if abs(_timestamp(row) - t) <= EMITTED_AT_TOLERANCE_S
|
||||
]
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
# ``stream`` is enforced non-None at MessageTurn construction time
|
||||
# (see ``MessageTurn.__post_init__``), so a missing stream here would
|
||||
# mean the dataclass invariant was bypassed; no need to re-check.
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the resolver is ambiguous.
|
||||
|
||||
Multiple matches always raise — even when the caller already passed
|
||||
some selectors — because remaining ambiguity means the data has
|
||||
several rows that look identical to the resolver and the caller
|
||||
needs to pin down a specific one (e.g. add ``camera=...`` for VQA
|
||||
rows shared across cameras).
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r} role={role!r} "
|
||||
f"tool_name={tool_name!r} camera={camera!r}: {len(rows)} matching rows. "
|
||||
f"Add a selector that distinguishes them."
|
||||
)
|
||||
return rows[0]
|
||||
|
||||
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = (
|
||||
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
|
||||
)
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
value = row["timestamp"]
|
||||
return float(value.item() if hasattr(value, "item") else value)
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -630,6 +630,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a new LeRobotDataset from scratch for recording data.
|
||||
|
||||
@@ -677,6 +679,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
metadata_buffer_size=metadata_buffer_size,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj._requested_root = obj.meta.root
|
||||
|
||||
@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
return self._datasets[0].meta.info.fps
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
return len(self._datasets[0].meta.video_keys) > 0
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
|
||||
@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _get_video_frame_padding_mask(
|
||||
self,
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -70,6 +72,9 @@ class ForwardCompatibilityError(CompatibilityError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
@@ -83,7 +88,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -94,6 +98,130 @@ LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
|
||||
|
||||
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
|
||||
created by ``create_empty_dataset_info()``. Using a dataclass provides
|
||||
explicit field definitions, IDE auto-completion, and validation at
|
||||
construction time.
|
||||
"""
|
||||
|
||||
codebase_version: str
|
||||
fps: int
|
||||
features: dict[str, dict]
|
||||
|
||||
# Episode / frame counters — start at zero for new datasets
|
||||
total_episodes: int = 0
|
||||
total_frames: int = 0
|
||||
total_tasks: int = 0
|
||||
|
||||
# Storage settings
|
||||
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
|
||||
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
|
||||
# File path templates
|
||||
data_path: str = field(default=DEFAULT_DATA_PATH)
|
||||
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
|
||||
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||
tools: list[dict] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
# returns lists, but the rest of the codebase expects tuples.
|
||||
for ft in self.features.values():
|
||||
if isinstance(ft.get("shape"), list):
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
if self.chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
|
||||
if self.data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
|
||||
if self.video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
Drops ``tools`` when unset so existing datasets keep a clean
|
||||
``info.json``.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
if d.get("tools") is None:
|
||||
d.pop("tools", None)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DatasetInfo":
|
||||
"""Construct from a raw dict (e.g. loaded directly from JSON).
|
||||
|
||||
Unknown keys are ignored for forward compatibility with datasets that
|
||||
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
|
||||
logged when such fields are present.
|
||||
"""
|
||||
known = {f.name for f in dataclasses.fields(cls)}
|
||||
unknown = sorted(k for k in data if k not in known)
|
||||
if unknown:
|
||||
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
|
||||
return cls(**{k: v for k, v in data.items() if k in known})
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Temporary dict-style compatibility layer
|
||||
# Allows existing ``info["key"]`` call-sites to keep working without changes.
|
||||
# Once all callers have been migrated to attribute access, remove these.
|
||||
# ---------------------------------------------------------------------------
|
||||
def __getitem__(self, key: str):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
|
||||
f"Use attribute access info.{key} instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError as err:
|
||||
raise KeyError(key) from err
|
||||
|
||||
def __setitem__(self, key: str, value) -> None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
|
||||
f"Use attribute assignment info.{key} = ... instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not hasattr(self, key):
|
||||
raise KeyError(f"DatasetInfo has no field '{key}'")
|
||||
setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Check if a field exists (dict-like interface)."""
|
||||
return hasattr(self, key)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
"""Get attribute value with default fallback (dict-like interface)."""
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
return default
|
||||
|
||||
|
||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||
|
||||
@@ -294,7 +422,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None,
|
||||
dataset_info: dict | None = None,
|
||||
dataset_info: DatasetInfo | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||
@@ -305,7 +433,7 @@ def create_lerobot_dataset_card(
|
||||
|
||||
Args:
|
||||
tags (list | None): A list of tags to add to the dataset card.
|
||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
||||
dataset_info (DatasetInfo | None): The dataset's info object, which will
|
||||
be displayed on the card.
|
||||
**kwargs: Additional keyword arguments to populate the card template.
|
||||
|
||||
@@ -318,7 +446,7 @@ def create_lerobot_dataset_card(
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
|
||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||
card_data = DatasetCardData(
|
||||
license=kwargs.get("license"),
|
||||
|
||||
@@ -12,8 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
@@ -21,10 +24,7 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .rtc import ActionInterpolator as ActionInterpolator
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
@@ -42,12 +42,11 @@ __all__ = [
|
||||
"DiffusionConfig",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"EO1Config",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"RewardClassifierConfig",
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
horizon: int = 64
|
||||
n_action_steps: int = 32
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
use_separate_rgb_encoder_per_camera: bool = True
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/eo1.mdx
|
||||
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
from .modeling_eo1 import EO1Policy
|
||||
from .processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLTextConfig,
|
||||
Qwen2_5_VLVisionConfig,
|
||||
)
|
||||
else:
|
||||
Qwen2_5_VLConfig = None
|
||||
Qwen2_5_VLTextConfig = None
|
||||
Qwen2_5_VLVisionConfig = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("eo1")
|
||||
@dataclass
|
||||
class EO1Config(PreTrainedConfig):
|
||||
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||
|
||||
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
vlm_config: dict | None = None
|
||||
|
||||
# Vision processor settings.
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
# Execution and action horizon.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 8
|
||||
n_action_steps: int = 8
|
||||
|
||||
# State/action padding to match EO1 flow head dimensionality.
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching sampling.
|
||||
num_denoise_steps: int = 10
|
||||
num_action_layers: int = 2
|
||||
action_act: str = "linear"
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
supervise_padding_action_dims: bool = True
|
||||
supervise_padding_actions: bool = True
|
||||
|
||||
# Policy-level dtype request for the Qwen backbone.
|
||||
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||
force_fp32_autocast: bool = True
|
||||
|
||||
# Optional attention backend request passed through to the Qwen backbone.
|
||||
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||
attn_implementation: str | None = None
|
||||
|
||||
# Training settings.
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.1
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# Populate the serialized backbone config only when the caller did not provide one.
|
||||
if self.vlm_config is None:
|
||||
require_package("transformers", extra="eo1")
|
||||
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||
require_package("transformers", extra="eo1")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
if self.attn_implementation is not None:
|
||||
config_dict["attn_implementation"] = self.attn_implementation
|
||||
return Qwen2_5_VLConfig(**config_dict)
|
||||
|
||||
@property
|
||||
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||
return self.vlm_backbone_config.text_config
|
||||
|
||||
@property
|
||||
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||
return self.vlm_backbone_config.vision_config
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up EO1 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"EO1 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,620 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils import torch_compilable_check
|
||||
else:
|
||||
ACT2FN = None
|
||||
Qwen2_5_VLForConditionalGeneration = None
|
||||
torch_compilable_check = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] >= new_dim:
|
||||
return vector
|
||||
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||
|
||||
|
||||
class EO1Policy(PreTrainedPolicy):
|
||||
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||
|
||||
config_class = EO1Config
|
||||
name = "eo1"
|
||||
|
||||
def __init__(self, config: EO1Config, **kwargs):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
if config.pretrained_path is None:
|
||||
# Initialize from pretrained VLM
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
config.vlm_base,
|
||||
dtype=config.dtype,
|
||||
attn_implementation=config.attn_implementation,
|
||||
)
|
||||
else:
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||
)
|
||||
|
||||
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
state = self.prepare_state(batch[OBS_STATE])
|
||||
actions = self.prepare_action(batch[ACTION])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||
loss = self.model(states=state, action=actions, **model_inputs)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
states = self.prepare_state(batch[OBS_STATE])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
return actions[:, :, :original_action_dim]
|
||||
|
||||
def prepare_state(self, state: Tensor) -> Tensor:
|
||||
return pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
def prepare_action(self, action: Tensor) -> Tensor:
|
||||
return pad_vector(action, self.config.max_action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 2,
|
||||
activation_layer: str = "linear",
|
||||
bias: bool = True,
|
||||
device: Any = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
layers = []
|
||||
in_dim = in_channels
|
||||
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||
for hidden_dim in hidden_channels[:-1]:
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||
layers.append(ACT2FN[activation_layer])
|
||||
in_dim = hidden_dim
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||
super().__init__(*layers)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self[0].weight.dtype
|
||||
|
||||
|
||||
class EO1VisionFlowMatchingModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: EO1Config,
|
||||
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||
):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||
self.vlm_backbone = vlm_backbone
|
||||
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_out_proj = EO1VisionActionProjector(
|
||||
self.hidden_size,
|
||||
max_action_dim,
|
||||
config.num_action_layers,
|
||||
config.action_act,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.vlm_backbone.get_input_embeddings()
|
||||
|
||||
def flow_head_autocast_context(self):
|
||||
if self.config.force_fp32_autocast:
|
||||
return torch.autocast(
|
||||
device_type=self.state_proj.weight.device.type,
|
||||
enabled=False,
|
||||
)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.vlm_backbone.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.vlm_backbone.gradient_checkpointing_disable()
|
||||
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(
|
||||
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||
)
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def get_placeholder_mask(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None,
|
||||
inputs_embeds: torch.FloatTensor | None,
|
||||
state_features: torch.FloatTensor | None = None,
|
||||
action_features: torch.FloatTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||
if input_ids is None:
|
||||
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_state_mask = special_state_mask.all(-1)
|
||||
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_action_mask = special_action_mask.all(-1)
|
||||
else:
|
||||
special_state_mask = input_ids == state_token_id
|
||||
special_action_mask = input_ids == action_token_id
|
||||
|
||||
n_state_tokens = special_state_mask.sum()
|
||||
special_state_mask = (
|
||||
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if state_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||
)
|
||||
|
||||
n_action_tokens = special_action_mask.sum()
|
||||
special_action_mask = (
|
||||
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if action_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||
)
|
||||
|
||||
return special_state_mask, special_action_mask
|
||||
|
||||
def embed_prefix(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
states: torch.Tensor,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||
|
||||
# Get the input embeddings for the input IDs
|
||||
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
return self.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||
|
||||
# Project the states to the hidden size
|
||||
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||
return self.state_proj(states)
|
||||
|
||||
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||
state_mask, _ = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_features=state_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||
return inputs_embeds
|
||||
|
||||
def embed_suffix(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
noisy_actions: torch.Tensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the suffix"""
|
||||
|
||||
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
time_embs = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.hidden_size,
|
||||
min_period=self.config.min_period,
|
||||
max_period=self.config.max_period,
|
||||
device=action_embs.device,
|
||||
)
|
||||
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||
|
||||
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||
action_time_embs = F.silu(action_time_embs)
|
||||
return self.action_time_mlp_out(action_time_embs)
|
||||
|
||||
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||
return action_time_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.FloatTensor | None = None,
|
||||
action: torch.FloatTensor | None = None,
|
||||
action_is_pad: torch.BoolTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||
|
||||
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
|
||||
# 2. Sample the diffusion target and replace the action placeholders.
|
||||
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||
u_t = noise - action
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
_, action_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
action_features=action_time_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||
|
||||
# 3. Optionally drop padded action tokens from backbone attention.
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||
action_token_mask = action_mask[..., 0]
|
||||
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||
action_padding_mask = action_padding_mask.masked_scatter(
|
||||
action_token_mask,
|
||||
action_is_pad.reshape(-1),
|
||||
)
|
||||
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||
|
||||
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||
def vlm_forward_func(
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
pixel_values: torch.Tensor | None,
|
||||
image_grid_thw: torch.LongTensor | None,
|
||||
mm_token_type_ids: torch.IntTensor | None,
|
||||
) -> torch.FloatTensor:
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
use_cache=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
hidden_states = self._apply_checkpoint(
|
||||
vlm_forward_func,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
mm_token_type_ids,
|
||||
)
|
||||
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||
|
||||
# 5. Project the action-token hidden states back to the flow target space.
|
||||
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
return self.action_out_proj(action_hidden_states)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
# 6. Apply the configured supervision mask and reduce the loss.
|
||||
if not self.config.supervise_padding_action_dims:
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[..., :original_action_dim]
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
losses = losses[~action_is_pad]
|
||||
|
||||
return losses.mean()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.Tensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Sample actions from the model."""
|
||||
if states is None:
|
||||
raise ValueError("states are required for EO1 action sampling.")
|
||||
if mm_token_type_ids is None:
|
||||
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||
|
||||
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||
chunk_size = self.config.chunk_size
|
||||
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
).clone()
|
||||
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_mask = action_placeholder_mask[..., 0]
|
||||
token_counts = action_mask.sum(dim=1)
|
||||
if not torch.all(token_counts == chunk_size):
|
||||
raise ValueError(
|
||||
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||
)
|
||||
if action_mask.ne(action_mask[:1]).any():
|
||||
raise ValueError(
|
||||
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||
)
|
||||
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||
act_end = act_start + self.config.chunk_size
|
||||
if not torch.all(action_mask[:, act_start:act_end]):
|
||||
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||
act_slice = slice(act_start, act_end)
|
||||
|
||||
# 2. Encode the fixed prefix once and cache its KV state.
|
||||
batch_size = input_ids.shape[0]
|
||||
device = inputs_embeds.device
|
||||
attention_mask = attention_mask.to(device)
|
||||
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
)
|
||||
position_ids = position_ids.to(device)
|
||||
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids[:, :act_start],
|
||||
attention_mask=attention_mask[:, :act_start],
|
||||
position_ids=position_ids[..., :act_start],
|
||||
inputs_embeds=inputs_embeds[:, :act_start],
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
x_t = self.sample_noise(
|
||||
(batch_size, chunk_size, self.config.max_action_dim),
|
||||
device,
|
||||
).to(dtype=self.action_in_proj.weight.dtype)
|
||||
dt = -1.0 / self.config.num_denoise_steps
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||
for step in range(self.config.num_denoise_steps):
|
||||
time = torch.full(
|
||||
(batch_size,),
|
||||
1.0 + step * dt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||
|
||||
# Keep the prefix KV cache invariant across denoising steps.
|
||||
past_key_values.crop(act_start)
|
||||
outputs = self.vlm_backbone.model(
|
||||
attention_mask=attention_mask[:, :act_end],
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds[:, act_slice],
|
||||
position_ids=position_ids[..., act_slice],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
with self.flow_head_autocast_context():
|
||||
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
v_t = self.action_out_proj(hidden_states)
|
||||
|
||||
x_t += dt * v_t.reshape(x_t.shape)
|
||||
|
||||
return x_t
|
||||
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
Qwen2_5_VLProcessor = None
|
||||
|
||||
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||
|
||||
# EO-1 special tokens
|
||||
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||
|
||||
EO1_SPECIAL_TOKENS = [
|
||||
ACTION_START_TOKEN,
|
||||
DEFAULT_ACTION_TOKEN,
|
||||
ACTION_END_TOKEN,
|
||||
STATE_START_TOKEN,
|
||||
DEFAULT_STATE_TOKEN,
|
||||
STATE_END_TOKEN,
|
||||
TASK_VLA_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||
chunk_size: int
|
||||
|
||||
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.input_features:
|
||||
first_val = next(iter(self.input_features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.input_features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.input_features = reconstructed
|
||||
|
||||
self._image_keys = [
|
||||
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||
]
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
tasks = complementary_data.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||
|
||||
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||
|
||||
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||
images = {
|
||||
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||
}
|
||||
messages = []
|
||||
for i in range(len(tasks)):
|
||||
content = [
|
||||
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||
),
|
||||
},
|
||||
]
|
||||
messages.append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||
{"role": "user", "content": content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
complementary_data["messages"] = messages
|
||||
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only materializes EO1-specific message objects in complementary_data.
|
||||
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||
feature contract change to record here.
|
||||
"""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||
},
|
||||
"chunk_size": self.chunk_size,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
require_package("transformers", extra="eo1")
|
||||
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
self.processor_name,
|
||||
use_fast=self.use_fast_processor,
|
||||
)
|
||||
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
messages = complementary_data.pop("messages", None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||
|
||||
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||
# Supervised batches use right padding to match standard training collation.
|
||||
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||
|
||||
inputs = self._processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True,
|
||||
padding_side=padding_side,
|
||||
min_pixels=self.image_min_pixels,
|
||||
max_pixels=self.image_max_pixels,
|
||||
add_generation_prompt=False,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
complementary_data["input_ids"] = inputs["input_ids"]
|
||||
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||
complementary_data["state_token_id"] = self._state_token_id
|
||||
complementary_data["action_token_id"] = self._action_token_id
|
||||
|
||||
return complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"processor_name": self.processor_name,
|
||||
"image_min_pixels": self.image_min_pixels,
|
||||
"image_max_pixels": self.image_max_pixels,
|
||||
"use_fast_processor": self.use_fast_processor,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only converts the messages to the model input format.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_eo1_pre_post_processors(
|
||||
config: EO1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build pre/post processor pipelines for EO1."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||
EO1QwenProcessorStep(
|
||||
processor_name=config.vlm_base,
|
||||
image_min_pixels=config.image_min_pixels,
|
||||
image_max_pixels=config.image_max_pixels,
|
||||
use_fast_processor=config.use_fast_processor,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -46,14 +46,13 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
@@ -89,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -132,18 +131,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "reward_classifier":
|
||||
from .sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "smolvla":
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from .sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from .groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -156,6 +147,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
elif name == "eo1":
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -173,7 +168,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
"smolvla", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -200,14 +195,14 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -378,14 +373,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
from .sac.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
@@ -394,14 +381,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from .sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from .groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -427,6 +406,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
processors = make_eo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
@@ -542,7 +528,7 @@ def make_policy(
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
|
||||
peft_pretrained_path = cfg.pretrained_path
|
||||
peft_pretrained_path = str(cfg.pretrained_path)
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||
@@ -555,7 +541,9 @@ def make_policy(
|
||||
)
|
||||
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
policy = PeftModel.from_pretrained(
|
||||
policy, peft_pretrained_path, config=peft_config, is_trainable=True
|
||||
)
|
||||
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -109,7 +109,6 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
|
||||
@@ -444,13 +444,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -666,8 +666,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -748,16 +747,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1292,8 +1283,11 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -728,14 +728,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1262,8 +1256,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
@@ -227,6 +226,7 @@ class PI0FastPaliGemma(nn.Module):
|
||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||
if use_adarms[0]:
|
||||
text_config = self.paligemma.config.text_config
|
||||
del self.paligemma.model.language_model
|
||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -260,13 +260,15 @@ class PI0FastPaliGemma(nn.Module):
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
norm = 2048**0.5
|
||||
features = features / norm * norm
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -416,8 +418,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -431,8 +432,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
return fast_emb
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
@@ -665,7 +665,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if t < max_decoding_steps - 1:
|
||||
# embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
@@ -770,7 +769,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Embed the single previous token
|
||||
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
@@ -197,6 +197,9 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
# Free parent-allocated layers/norm before replacing to avoid ~2x peak memory.
|
||||
del self.layers
|
||||
del self.norm
|
||||
# if not getattr(config, "use_adarms", False):
|
||||
# return
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||
@@ -328,6 +331,7 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
del self.model
|
||||
self.model = PiGemmaModel(config)
|
||||
|
||||
|
||||
@@ -336,6 +340,7 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
del self.language_model
|
||||
self.language_model = PiGemmaModel(config.text_config)
|
||||
|
||||
|
||||
@@ -344,6 +349,7 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
del self.model
|
||||
self.model = PaliGemmaModelWithPiGemma(config)
|
||||
|
||||
# Make modules available through conditional class for BC
|
||||
|
||||
@@ -19,6 +19,7 @@ from .action_queue import ActionQueue
|
||||
from .configuration_rtc import RTCConfig
|
||||
from .latency_tracker import LatencyTracker
|
||||
from .modeling_rtc import RTCProcessor
|
||||
from .relative import reanchor_relative_rtc_prefix
|
||||
|
||||
__all__ = [
|
||||
"ActionInterpolator",
|
||||
@@ -26,4 +27,5 @@ __all__ = [
|
||||
"LatencyTracker",
|
||||
"RTCConfig",
|
||||
"RTCProcessor",
|
||||
"reanchor_relative_rtc_prefix",
|
||||
]
|
||||
|
||||
@@ -1,116 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
# First step: no previous action yet, so run at base FPS without interpolation.
|
||||
self._buffer = [action.clone()]
|
||||
self._prev = action.clone()
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
__all__ = ["ActionInterpolator"]
|
||||
|
||||
@@ -92,10 +92,10 @@ class ActionQueue:
|
||||
Returns:
|
||||
int: Number of unconsumed actions.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
return length - self.last_index
|
||||
with self.lock:
|
||||
if self.queue is None:
|
||||
return 0
|
||||
return len(self.queue) - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the queue is empty.
|
||||
@@ -103,11 +103,10 @@ class ActionQueue:
|
||||
Returns:
|
||||
bool: True if no actions remain, False otherwise.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index <= 0
|
||||
with self.lock:
|
||||
if self.queue is None:
|
||||
return True
|
||||
return len(self.queue) - self.last_index <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
"""Get the current action consumption index.
|
||||
@@ -115,7 +114,8 @@ class ActionQueue:
|
||||
Returns:
|
||||
int: Index of the next action to be consumed.
|
||||
"""
|
||||
return self.last_index
|
||||
with self.lock:
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor | None:
|
||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||
|
||||
@@ -35,7 +35,7 @@ class RTCConfig:
|
||||
"""
|
||||
|
||||
# Infrastructure
|
||||
enabled: bool = False
|
||||
enabled: bool = True
|
||||
|
||||
# Core RTC settings
|
||||
# Todo change to exp
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Relative-action helpers for Real-Time Chunking (RTC)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
to_relative_actions,
|
||||
)
|
||||
|
||||
|
||||
def reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: torch.Tensor,
|
||||
current_state: torch.Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> torch.Tensor:
|
||||
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
|
||||
|
||||
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
|
||||
is stored in absolute coordinates. Before feeding it back to the policy, this
|
||||
helper re-expresses those actions relative to the robot's current joint state
|
||||
and optionally normalizes them so the policy receives correctly scaled inputs.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user