Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-13 07:39:53 +00:00) — compare view, 126 commits.
@@ -178,3 +178,9 @@ test-smolvla-ete-eval:
	--env.episode_length=5 \
	--eval.n_episodes=1 \
	--eval.batch_size=1

# E2E annotation pipeline smoke test against a tiny in-memory fixture
# dataset. Opt-in (not part of `make test-end-to-end`) and uses a stub VLM
# backend, so it does not require a real model checkpoint or GPU.
annotation-e2e:
	uv run python -m tests.annotations.run_e2e_smoke

@@ -31,8 +31,12 @@
        title: Porting Large Datasets
      - local: using_dataset_tools
        title: Using the Dataset Tools
      - local: dataset_subtask
        title: Using Subtasks in the Dataset
      - local: language_and_recipes
        title: Language Columns and Recipes
      - local: tools
        title: Tools
      - local: annotation_pipeline
        title: Annotation Pipeline
      - local: streaming_video_encoding
        title: Streaming Video Encoding
      title: "Datasets"

@@ -0,0 +1,161 @@
# Annotation Pipeline

`lerobot-annotate` populates the two language columns introduced by the
[Language Columns and Recipes](./language_and_recipes) page —
`language_persistent` and `language_events` — directly into
`data/chunk-*/file-*.parquet`. There is no flavor namespace and no sidecar
file tree: multiple revisions of a dataset mean multiple dataset copies.

## What the pipeline produces

Three modules write into a per-episode staging tree, then a single writer
rewrites the data shards in place:

| Style / atom                                | Column                | Module   |
| ------------------------------------------- | --------------------- | -------- |
| `subtask` (Pi0.7-style "how, not what")     | `language_persistent` | Module 1 |
| `plan` (initial + refresh on interjection)  | `language_persistent` | Module 1 |
| `memory` (MEM-style compression)            | `language_persistent` | Module 1 |
| `interjection`                              | `language_events`     | Module 2 |
| speech tool-call atom (`style=null`, `say`) | `language_events`     | Module 2 |
| `vqa` (user / assistant pair)               | `language_events`     | Module 3 |

The writer drops the legacy `subtask_index` column. It does **not** add a
`tools` column to the parquet — the tool catalog lives at
`meta/info.json["tools"]` instead (see [Tools](./tools)). After every
annotation run the pipeline ensures the canonical `say` schema is
present in that list, preserving any tools the user pre-declared. Chat-
template consumers read the catalog through
`LeRobotDatasetMetadata.tools` and pass it to
`apply_chat_template(messages, tools=meta.tools, ...)`.

If you want to declare additional tools for a dataset before annotation
runs, edit `meta/info.json["tools"]` directly — the pipeline preserves
anything already there. Implementations of those tools live under
`src/lerobot/tools/`; one file per tool, registered via
`TOOL_REGISTRY`. See the [Tools](./tools) doc for the authoring guide.

## How to run it locally or on SLURM

Install the extra and invoke the console script:

```bash
uv sync --extra annotations
uv run lerobot-annotate \
    --repo_id=imstevenpmwork/super_poulain_draft \
    --vlm.backend=vllm \
    --vlm.model_id=Qwen/Qwen3.6-27B-FP8 \
    --vlm.tensor_parallel_size=2
```

The pipeline attaches actual camera footage to every Module 1/2/3 prompt
by default, decoded from the dataset's first `observation.images.*`
stream. Override with `--vlm.camera_key=observation.images.<name>` to
pin a specific viewpoint. Datasets with no video tracks fall back to
text-only prompts automatically.

**Module 1 sees the whole episode as one video block.** Subtask
decomposition gets a `{"type":"video", "video":[<frames>]}` block
covering the entire demonstration; Qwen-VL pools temporally on its own
and decides where to cut. There is no keyframe stride or count knob —
`--module_1.max_video_frames` (default 128) only caps the frames packed
into the video block as a model-capacity bound. Module 2 attaches a
single still frame at the interjection timestamp; Module 3 attaches the
exact emission frame to each VQA pair.

The executor picks `LocalPipelineExecutor` for small datasets and
`SlurmPipelineExecutor` for large ones based on
`--executor.auto_threshold` (default 32 episodes). Force local with
`--executor.force_local=true`. SLURM jobs honour `--executor.slurm_partition`,
`--executor.slurm_gpus`, and `--executor.slurm_time`.

## Style-to-recipe consumer mapping

The pipeline produces exactly the styles consumed by
`src/lerobot/configs/recipes/pi05_hirobot.yaml`:

- `low_level_execution`, `high_level_subtask`, `memory_update` consume
  `subtask`/`plan`/`memory` from `language_persistent`.
- `user_interjection_response` consumes `interjection` events plus the
  paired speech atom (merged into one assistant target turn via
  `tool_calls_from`) and the same-timestamp `plan` refresh.
- `ask_vqa` consumes the `(vqa, user)` and `(vqa, assistant)` pairs from
  `language_events`.

## Why the design is scoped to the canonical recipe

Two things drive the scope:

1. **Persistent state vs exact-event split.** Persistent rows (`subtask`,
   `plan`, `memory`) broadcast per episode and answer "what state is in
   force at this frame?". Event rows (`interjection`, `vqa`, speech) only
   appear on the exact frame whose timestamp matches the emission. The
   pipeline writes timestamps taken straight from the source parquet — no
   floating-point recomputation.
2. **One Qwen-VL pass.** All three modules share a single VLM client
   (vLLM if available, transformers fallback) so the cost is one model
   load per dataset, not three.

## Module independence and staged reruns

Each module writes its raw output to
`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. That makes
prompt iteration cheap — re-running one module overwrites only its own
JSONL file before the writer composes the final parquet. Modules can be
disabled via `--module_1.enabled=false` (and similarly for 2 and 3) to
test them in isolation.

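
To sanity-check what a module staged before the final write, you can read
its JSONL directly. A minimal sketch — the local dataset root and the
per-module filenames are assumptions, so list the episode directory to see
what your run actually produced:

```python
# Hedged sketch: peek at the staged rows for episode 0 before the writer runs.
import json
from pathlib import Path

root = Path("path/to/your/dataset_root")  # wherever the dataset lives locally
episode_dir = root / ".annotate_staging" / "episode_000000"
for jsonl_file in sorted(episode_dir.glob("*.jsonl")):
    print(f"--- {jsonl_file.name} ---")
    for line in jsonl_file.read_text().splitlines():
        row = json.loads(line)
        print(row.get("style"), row.get("timestamp"), (row.get("content") or "")[:60])
```
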
## Validation/report checks before final write

Before the writer runs, `StagingValidator` checks:

- exact frame-timestamp alignment for every event row;
- no orphan speech / interjection pairs;
- `plan` is refreshed at every interjection timestamp;
- `memory` rows fall on subtask boundaries (warning, not error);
- VQA assistant `content` parses as JSON in one of the
  bbox / keypoint / count / attribute / spatial shapes;
- every row routes to the column dictated by `column_for_style(style)`.

Errors abort the writer (`--skip_validation=true` overrides for debugging).

## Paper inspirations per module

- **Module 1 — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417))
  atom granularity ("pick up one piece of lettuce", "place bowl to box");
  Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) "how, not
  what" detail.
- **Module 1 — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596))
  compression directive: keep only minimal relevant information; functional
  outcomes preserved, specific attributes dropped.
- **Module 2 — interjections.** Hi Robot scenario taxonomy: negative task,
  situated correction, specific constraint, preference. Speech is a
  tool-call-only atom (`tool_calls=[{type:function, function:{name:"say",
  arguments:{text:...}}}]`).
- **Module 3 — VQA.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693))
  grounded features (bounding boxes in pixel `[x_min, y_min, x_max, y_max]`,
  keypoints) and Steerable Policies' multi-abstraction grounding.

Future maintainers should adjust the prompt templates in
`src/lerobot/annotations/steerable_pipeline/prompts/` against these
references rather than rewriting from scratch.

## Compute and list-size estimates

Per episode, the pipeline issues O(`max_steps`) Module 1 calls,
O(`max_interjections_per_episode`) Module 2 calls, and
O(`vqa_emission_hz × episode_seconds`) Module 3 calls. With defaults
(8 subtasks, 1 interjection, 1 Hz × 3 pairs) and 30-second episodes, that
is ~50 VLM calls per episode. `language_persistent` per episode is ~10s of
KB at most (parquet dictionary-encodes one entry per episode);
`language_events` is empty on most frames and is bounded by the number of
emissions, not `num_frames × num_emissions`.

## Reproducibility via seed and prompt hashes

`--seed` (default 1729) feeds the per-episode RNGs that select interjection
timestamps and VQA question types. Combined with the deterministic prompt
templates checked into `prompts/`, two runs at the same seed against the
same dataset and the same model checkpoint produce byte-identical staging
artifacts. Prompt edits are recorded by file hash; future tooling can pin
expected `(seed, prompt_hash)` pairs into the dataset card.

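
The exact hashing scheme is not pinned down here; a minimal sketch that
fingerprints each checked-in template file (SHA-256 over file bytes is an
assumption) looks like:

```python
# Hedged sketch: fingerprint the prompt templates so a (seed, prompt_hash)
# pair can be recorded alongside a run. Hashing scheme is an assumption.
import hashlib
from pathlib import Path

prompt_dir = Path("src/lerobot/annotations/steerable_pipeline/prompts")
prompt_hashes = {
    p.name: hashlib.sha256(p.read_bytes()).hexdigest()[:12]
    for p in sorted(prompt_dir.glob("*.txt"))
}
print({"seed": 1729, "prompt_hashes": prompt_hashes})
```
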
@@ -1,277 +0,0 @@
# Using Subtasks in LeRobot Datasets

Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:

- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps

LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.

## What are Subtasks?

While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:

1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"

Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.

<img
  src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
  alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
  width="80%"
/>

<p>
  <em>Figure: Overview of subtask annotation.</em>
</p>

**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, published 19 April 2022.

## Dataset Structure

Subtask information is stored in the dataset metadata:

```
my-dataset/
├── data/
│   └── ...
├── meta/
│   ├── info.json
│   ├── stats.json
│   ├── tasks.parquet
│   ├── subtasks.parquet   # Subtask index → subtask string mapping
│   └── episodes/
│       └── ...
└── videos/
    └── ...
```

### Subtasks Parquet File

The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:

| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0             | "Approach the apple"   |
| 1             | "Grasp the apple"      |
| 2             | "Lift the apple"       |
| ...           | ...                    |

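
You can also inspect this mapping directly with pandas (a quick sketch; as
the table above shows, the subtask string is the DataFrame index):

```python
# Quick look at the subtask mapping without loading the full dataset.
import pandas as pd

subtasks = pd.read_parquet("my-dataset/meta/subtasks.parquet")
print(subtasks.head())  # index: subtask string, column: subtask_index
```
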
### Frame-Level Annotations

Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:

```python
# Example frame data in the parquet file
{
    "index": 42,
    "timestamp": 1.4,
    "episode_index": 0,
    "task_index": 0,
    "subtask_index": 2,  # References "Lift the apple"
    "observation.state": [...],
    "action": [...],
}
```

## Annotating Datasets with Subtasks

We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:

**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**

After completing your annotation:

1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)

## Loading Datasets with Subtasks

When you load a dataset with subtask annotations, the subtask information is automatically available:

```python
from lerobot.datasets import LeRobotDataset

# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")

# Access a sample
sample = dataset[100]

# The sample includes both task and subtask information
print(sample["task"])           # "Collect the fruit"
print(sample["subtask"])        # "Grasp the apple"
print(sample["task_index"])     # tensor(0)
print(sample["subtask_index"])  # tensor(2)
```

### Checking for Subtask Support

You can check if a dataset has subtask annotations:

```python
# Check if subtasks are available
has_subtasks = (
    "subtask_index" in dataset.features
    and dataset.meta.subtasks is not None
)

if has_subtasks:
    print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
    print("Subtasks:", list(dataset.meta.subtasks.index))
```

## Using Subtasks for Training

### With the Tokenizer Processor

The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:

```python
from lerobot.processor import TokenizerProcessorStep

# Create a tokenizer processor step
tokenizer_processor = TokenizerProcessorStep(
    tokenizer_name_or_path="google/paligemma-3b-pt-224",
    padding="max_length",
    max_length=64,
)

# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
#   - "observation.subtask.tokens"
#   - "observation.subtask.attention_mask"
```

When subtasks are available in the batch, the tokenizer processor adds:

- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens

### DataLoader with Subtasks

```python
import torch
from lerobot.datasets import LeRobotDataset

dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
)

for batch in dataloader:
    # Access subtask information in the batch
    subtasks = batch["subtask"]              # List of subtask strings
    subtask_indices = batch["subtask_index"]  # Tensor of subtask indices

    # Use for training hierarchical policies or reward models
    print(f"Batch subtasks: {set(subtasks)}")
```

## Example Datasets with Subtask Annotations

Try loading a dataset with subtask annotations:

```python
from lerobot.datasets import LeRobotDataset

# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")

# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
    print(f"  - {subtask_name}")

# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
    sample = dataset[i]
    subtask = sample["subtask"]
    subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1

print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
    print(f"  {subtask}: {count} frames")
```

## Use Cases

### 1. Hierarchical Policy Training

Train policies that predict both actions and current subtask:

```python
import torch.nn as nn


class HierarchicalPolicy(nn.Module):
    def __init__(self, encoder: nn.Module, hidden_dim: int, action_dim: int, num_subtasks: int):
        super().__init__()
        self.encoder = encoder  # any observation encoder producing hidden_dim features
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.subtask_head = nn.Linear(hidden_dim, num_subtasks)

    def forward(self, observations):
        features = self.encoder(observations)
        actions = self.action_head(features)
        subtask_logits = self.subtask_head(features)
        return actions, subtask_logits
```

### 2. Stage-Aware Reward Modeling (SARM)

Build reward models that understand task progression:

```python
# SARM predicts:
#  - Stage: Which subtask is being executed (discrete)
#  - Progress: How far along the subtask (continuous 0-1)


class SARMRewardModel(nn.Module):
    # encoder, stage_classifier, and progress_regressor are built in __init__ (omitted here)
    def forward(self, observations):
        features = self.encoder(observations)
        stage_logits = self.stage_classifier(features)
        progress = self.progress_regressor(features)
        return stage_logits, progress
```

### 3. Progress Visualization

Monitor robot execution by tracking subtask progression:

```python
def visualize_execution(model, observations, subtask_names):
    for t, obs in enumerate(observations):
        action, subtask_logits = model(obs)
        predicted_subtask = subtask_names[subtask_logits.argmax()]
        print(f"t={t}: Executing '{predicted_subtask}'")
```


## API Reference

### LeRobotDataset Properties

| Property                    | Type                   | Description                                |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks`             | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict`                 | Feature spec for subtask_index if present  |

### Sample Keys

When subtasks are available, each sample includes:

| Key             | Type           | Description                          |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask`       | `str`          | Natural language subtask description |

## Related Resources

- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation

@@ -0,0 +1,109 @@
# Language columns and recipes

LeRobot stores reusable language annotations directly next to frame data in `data/chunk-*/file-*.parquet`.
The two optional columns are:

- `language_persistent`: a list of rows broadcast across every frame in an episode for state that remains active, such as `subtask`, `plan`, and `memory`.
- `language_events`: a list of rows only on the exact frame where an event was emitted, such as `interjection`, `vqa`, and speech tool calls.

Both columns share the same row shape (event rows omit `timestamp` because the
frame the row sits on already provides it):

```text
role: string
content: string | null
style: string | null
timestamp: float64             # persistent rows only
camera: string | null          # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null
```

The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa`, and the reserved `motion` /
`trace`) MUST set `camera` to the matching `observation.images.*` feature key.
Rows of every other style MUST leave `camera` as `null`. Pipeline writers and
the validator enforce this via `validate_camera_field(style, camera)`.

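
Concretely, a persistent row and an event row might look like this (the field
names follow the schema above; the values are illustrative):

```python
# Illustrative rows only — values are made up.
persistent_row = {
    "role": "assistant",
    "content": "grasp one piece of lettuce",
    "style": "subtask",
    "timestamp": 3.2,   # persistent rows carry the emission timestamp
    "camera": None,     # non view-dependent style -> camera stays null
    "tool_calls": None,
}
event_row = {
    "role": "user",
    "content": "How many bowls are on the table?",
    "style": "vqa",
    "camera": "observation.images.top",  # view-dependent style -> camera required
    "tool_calls": None,
}
```
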
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.

## Architecture

The language stack has three layers:

1. `lerobot.datasets.language` defines the schema, style registry, and `column_for_style` (see the sketch below).
2. `lerobot.datasets.language_render` resolves rows and renders messages.
3. `RenderMessagesStep` turns dataset samples into `messages`, `message_streams`, and `target_message_indices`.

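
A quick illustration of the first layer's routing helper — the exact return
type is an assumption (shown here as the column-name string), but the routing
follows the persistent/event split described above:

```python
from lerobot.datasets.language import column_for_style

print(column_for_style("subtask"))       # -> "language_persistent"
print(column_for_style("interjection"))  # -> "language_events"
print(column_for_style("vqa"))           # -> "language_events"
```
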
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.

## Temporal semantics

Persistent styles are active after emission until replaced:

- `active_at(t, style=subtask)`
- `nth_prev(style=memory, offset=1)`
- `nth_next(style=subtask, offset=1)`

Event styles only exist on their exact timestamp:

- `emitted_at(t, style=interjection)`
- `emitted_at(t, style=vqa, role=user, camera=observation.images.top)`
- `emitted_at(t, role=assistant, tool_name=say)`

Exact event matching has no tolerance window, so writers must stamp event rows with frame timestamps from the parquet data.

## View-dependent resolution

For view-dependent styles (`vqa`, `motion`, `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two
matches and raise an ambiguity error. Recipes consume each camera through its
own binding plus a matching image block, e.g.

```yaml
ask_vqa_top:
  bindings:
    vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
    vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
  messages:
    - role: user
      stream: high_level
      if_present: vqa_query
      content:
        - { type: image, feature: observation.images.top }
        - { type: text, text: "${vqa_query}" }
    - { role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa }
```

Add one such sub-recipe per camera the dataset records.

## Recipe anatomy

Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`.

```yaml
messages:
  - { role: user, content: "${task}", stream: high_level }
  - { role: assistant, content: "${subtask}", stream: low_level, target: true }
```

Rendered samples use HF-style chat messages plus LeRobot sidecars:

```python
sample["messages"]
sample["message_streams"]
sample["target_message_indices"]
```

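
For the two-turn recipe above, a rendered sample might carry (illustrative
values only):

```python
# Illustrative shape of a rendered sample for the two-turn recipe above.
{
    "messages": [
        {"role": "user", "content": "Pick up the apple and place it in the basket"},
        {"role": "assistant", "content": "grasp the apple"},
    ],
    "message_streams": ["high_level", "low_level"],  # one entry per message
    "target_message_indices": [1],                   # only the assistant turn is a loss target
}
```
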
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone.

## Blends

Blend recipes select one weighted sub-recipe deterministically from the sample index.
The canonical `recipes/pi05_hirobot.yaml` combines memory updates, interjection responses, high-level subtask prediction, low-level execution, and VQA.

## Graceful absence

If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
If an event-scoped branch is selected on a frame without the required event row, rendering returns `None`, allowing a loader to retry another sample.

@@ -0,0 +1,198 @@
# Tools

LeRobot v3.1 supports **tool calls** in policies — assistant messages can
emit structured invocations like `say(text="OK, starting now")` that the
runtime dispatches to a real implementation (TTS, controller, logger, …).

This page covers:

1. Where the tool catalog lives (PR 1).
2. How the annotation pipeline produces tool-call atoms (PR 2).
3. How to add your own tool (PR 3).

## Where tools are declared

Two layers.

**The catalog** — a list of OpenAI-style function schemas — lives at
`meta/info.json["tools"]` on each dataset. Example:

```json
{
  "features": { "...": "..." },
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "say",
        "description": "Speak a short utterance to the user via the TTS executor.",
        "parameters": {
          "type": "object",
          "properties": {
            "text": { "type": "string", "description": "The verbatim text to speak." }
          },
          "required": ["text"]
        }
      }
    }
  ]
}
```

Read it via the dataset metadata accessor:

```python
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
tools = meta.tools  # list[dict] — OpenAI tool schemas
```

If the dataset's `info.json` doesn't declare any tools, `meta.tools`
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
single-entry list with the canonical `say` schema. So unannotated
datasets and chat-template consumers keep working without any
configuration:

```python
prompt_str = tokenizer.apply_chat_template(
    sample["messages"],
    tools=meta.tools,  # works either way
    add_generation_prompt=False,
    tokenize=False,
)
```

**The implementations** — runnable Python — live under
`src/lerobot/tools/`, one file per tool. The `say` implementation
arrives in PR 3 and wraps Kyutai's pocket-tts model.

## Per-row tool *invocations*

The catalog above describes *what can be called*. The actual *call* — the
function name plus the argument values — is stored per-row, on the
assistant atoms in `language_events`:

```python
{
    "role": "assistant",
    "content": None,
    "style": None,
    "timestamp": 12.4,
    "camera": None,
    "tool_calls": [
        {"type": "function",
         "function": {"name": "say", "arguments": {"text": "On it."}}}
    ],
}
```


Recipes splice these into rendered messages via `tool_calls_from`:

```yaml
user_interjection_response:
  bindings:
    speech: "emitted_at(t, role=assistant, tool_name=say)"
  messages:
    - { role: user, content: "${task}", stream: high_level }
    - { role: assistant, content: "${current_plan}", stream: high_level,
        target: true, tool_calls_from: speech }
```

The model's training target is one assistant turn that carries both the
plan text *and* the `say` tool call. At inference, the runtime parses
the generated text back into structured `tool_calls` and dispatches to
the matching implementation.

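
Rendered, that merged target turn looks roughly like this (illustrative
values only):

```python
# Illustrative merged assistant target turn: plan text plus the spliced-in say call.
{
    "role": "assistant",
    "content": "Put the lettuce back, then continue packing the bowl.",
    "tool_calls": [
        {"type": "function", "function": {"name": "say", "arguments": {"text": "On it."}}}
    ],
}
```
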
## How to add your own tool

Three steps. Concrete example: a `record_observation` tool the policy
can call to capture an extra observation outside the regular control
loop.

### Step 1 — declare the schema

Add an entry under `meta/info.json["tools"]`. Either edit the file
directly on disk *before* running the annotation pipeline (it'll be
preserved) or hand it to `lerobot-annotate` via a config flag (PR 2 —
exact CLI lands with the pipeline change).

```json
{
  "tools": [
    { "type": "function", "function": { "name": "say", "...": "..." } },
    {
      "type": "function",
      "function": {
        "name": "record_observation",
        "description": "Capture a high-resolution still image for the user.",
        "parameters": {
          "type": "object",
          "properties": {
            "label": { "type": "string", "description": "Short label for the saved image." }
          },
          "required": ["label"]
        }
      }
    }
  ]
}
```

The schema follows OpenAI's function-calling convention exactly, so the
chat template can render it natively.

### Step 2 — implement the call

Create `src/lerobot/tools/record_observation.py`:

```python
from typing import Any

from .base import Tool  # structural protocol every tool satisfies (PR 3)

RECORD_OBSERVATION_SCHEMA: dict[str, Any] = {"...": "..."}  # mirrors the JSON above


class RecordObservationTool:
    name = "record_observation"
    schema = RECORD_OBSERVATION_SCHEMA

    def __init__(self, schema: dict | None = None, output_dir: str = "."):
        self.schema = schema or RECORD_OBSERVATION_SCHEMA
        self.output_dir = output_dir

    def call(self, arguments: dict) -> str:
        label = arguments["label"]
        # ... save the latest camera frame to <output_dir>/<label>.png ...
        return f"saved {label}.png"
```


One file per tool keeps dependencies isolated — `record_observation`
might pull `pillow`, while `say` (PR 3) pulls `pocket-tts`. Users
installing only the tools they need avoid heavy transitive deps.

### Step 3 — register it

Add to `src/lerobot/tools/registry.py` (PR 3):

```python
from .record_observation import RecordObservationTool

TOOL_REGISTRY["record_observation"] = RecordObservationTool
```

That's it. At runtime `get_tools(meta)` looks up each schema in
`meta.tools`, instantiates the matching registered class, and returns
a name → instance dict the dispatcher can route into.

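
A hedged sketch of what that dispatch loop could look like at inference time
(the `get_tools` import path and the parsed `tool_calls` shape follow the
description above; PR 3 may differ in detail):

```python
# Hedged sketch of runtime dispatch; assumes get_tools() lives in the PR 3 registry
# module and that tool_calls were already parsed out of the generated assistant turn.
from lerobot.tools.registry import get_tools

tools = get_tools(meta)  # e.g. {"say": SayTool(...), "record_observation": RecordObservationTool(...)}
for call in assistant_turn["tool_calls"]:
    fn = call["function"]
    result = tools[fn["name"]].call(fn["arguments"])
    print(f"{fn['name']} -> {result}")
```
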
## Where this fits in the three-PR stack

| Layer | PR | What lands |
|---|---|---|
| Catalog storage in `meta/info.json` + `meta.tools` accessor | PR 1 | This page; `SAY_TOOL_SCHEMA`, `DEFAULT_TOOLS` constants in `lerobot.datasets.language`; `LeRobotDatasetMetadata.tools` property |
| Annotation pipeline writes `tools` to meta after a run; honors anything users pre-populated | PR 2 | `lerobot-annotate` ensures `meta/info.json["tools"]` includes the canonical `say` and merges any user-declared tools |
| Runnable implementations under `src/lerobot/tools/`; runtime dispatcher; `say.py` wired to Kyutai's pocket-tts | PR 3 | One file per tool; `Tool` protocol; `TOOL_REGISTRY`; optional `[tools]` extra in `pyproject.toml` |

If you want to use a tool *without* writing an implementation (e.g. for
training-time chat-template formatting only), step 1 alone is enough —
the model still learns to *generate* the call. Steps 2 and 3 are only
needed to actually *execute* it at inference.

@@ -0,0 +1,81 @@
#!/usr/bin/env python
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).

Spawns one ``h200x2`` job that:

1. installs this branch of ``lerobot`` plus the annotation extras,
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
3. runs Module 1/2/3 across the dataset (per-camera VQA via PR 3471),
4. uploads the annotated dataset to ``--push_to_hub``.

Usage:

    HF_TOKEN=hf_... uv run python examples/annotation/run_hf_job.py

Adjust ``CMD`` below to point at your own dataset / target hub repo.
"""

import os

from huggingface_hub import get_token, run_job

token = os.environ.get("HF_TOKEN") or get_token()
if not token:
    raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")

# --- Diversity knobs (Pi0.7-style prompt expansion) -----------------------
# Bumped roughly 3x across the board to fight memorization on small datasets.
# A single dataset trained for many epochs with deterministic atom wording
# converges to perfect recall on training prompts but produces JSON-token
# garbage at inference for any wording that drifts slightly. More atom
# variants per episode + higher sampling temperature widens the training
# distribution so the model has to actually use its language head, not
# just memorize.
#
# Pushes to a *new* hub repo (``_tool3``) so the previous annotation pass
# (``_tool2``) stays intact — re-train from scratch on the new dataset and
# compare loss-curve shapes to verify the diversity bump is doing something.
CMD = (
    "apt-get update -qq && apt-get install -y -qq git ffmpeg && "
    "pip install --no-deps "
    "'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
    "pip install --upgrade-strategy only-if-needed "
    "datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect && "
    "export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
    "export VLLM_VIDEO_BACKEND=pyav && "
    "lerobot-annotate "
    "--repo_id=imstevenpmwork/super_poulain_draft "
    "--vlm.backend=openai "
    "--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
    "--vlm.parallel_servers=2 "
    "--vlm.num_gpus=2 "
    '--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
    "--tensor-parallel-size 1 --max-model-len 32768 "
    '--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
    "--vlm.serve_ready_timeout_s=1800 "
    "--vlm.client_concurrency=128 "
    "--vlm.max_new_tokens=512 "
    "--vlm.temperature=0.7 "
    "--executor.episode_parallelism=16 "
    "--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
    "--vlm.camera_key=observation.images.wrist "
    "--module_1.frames_per_second=1.0 "
    "--module_1.use_video_url=true "
    "--module_1.use_video_url_fps=1.0 "
    "--module_1.derive_task_from_video=always "
    "--module_1.n_task_rephrasings=30 "
    "--module_2.max_interjections_per_episode=6 "
    "--module_3.K=3 "
    "--module_3.vqa_emission_hz=1.0 "
    "--push_to_hub=pepijn223/super_poulain_full_tool3"
)

job = run_job(
    image="vllm/vllm-openai:latest",
    command=["bash", "-c", CMD],
    flavor="h200x2",
    secrets={"HF_TOKEN": token},
    timeout="2h",
)
print(f"Job URL: {job.url}")
print(f"Job ID: {job.id}")

@@ -0,0 +1,80 @@
#!/bin/bash
#SBATCH --job-name=smolvla2-hirobot
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=48:00:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=8

# SmolVLA2 training on an annotated dataset.
#
# The high_level_subtask recipe (recipes/smolvla2_hirobot.yaml) was
# fixed in PR3 to supervise the LM head with the *current* active
# subtask span at every frame, not the next-span target which is
# empty on stable phases. With the old recipe the head learned to
# emit ``\n`` on every chunk boundary; the new one supervises a
# real, scene-grounded string at every frame.
#
# Two regularisers are still on:
#
# * --dataset.image_transforms.enable=true: torchvision-v2
#   ColorJitter + SharpnessJitter + RandomAffine per frame; default
#   envelope (brightness ±20% etc).
# * --policy.{plan,memory,subtask}_dropout_prob: randomly drop the
#   context messages carrying the named recipe binding so the model
#   handles missing/stale context. Mirrors Pi0.7 §V.E.

set -euo pipefail

cd "${LEROBOT_ROOT:-$HOME/lerobot}"

export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"

DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/smolvla2_hirobot_super_poulain_tool6}"
JOB_NAME="${JOB_NAME:-smolvla2-hirobot-super-poulain-tool6}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
BATCH_SIZE="${BATCH_SIZE:-32}"
STEPS="${STEPS:-2000}"
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/smolvla2_hirobot_super_poulain_tool3_${STEPS}_${RUN_ID}}"

echo "Training smolvla2 on $DATASET"
echo "  GPUs: $NUM_PROCESSES"
echo "  batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
echo "  steps: $STEPS"
echo "  output: $OUTPUT_DIR"
echo "  augmentation: image_transforms ON, prompt dropout {plan:0.50 memory:0.50 subtask:0.20}"

accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
    -m lerobot.scripts.lerobot_train \
    --policy.type=smolvla2 \
    --policy.recipe_path=recipes/smolvla2_hirobot.yaml \
    --dataset.repo_id="$DATASET" \
    --dataset.revision=main \
    --dataset.video_backend=pyav \
    --output_dir="$OUTPUT_DIR" \
    --job_name="$JOB_NAME" \
    --policy.repo_id="$POLICY_REPO_ID" \
    --policy.compile_model=false \
    --policy.device=cuda \
    --policy.tokenizer_max_length=512 \
    --steps="$STEPS" \
    --policy.scheduler_decay_steps="$STEPS" \
    --batch_size="$BATCH_SIZE" \
    --wandb.enable=true \
    --wandb.disable_artifact=true \
    --wandb.project=hirobot \
    --log_freq=100 \
    --save_freq="$STEPS" \
    --num_workers=0 \
    --dataset.image_transforms.enable=true \
    --dataset.image_transforms.max_num_transforms=3 \
    --dataset.image_transforms.random_order=true \
    --policy.plan_dropout_prob=0.50 \
    --policy.memory_dropout_prob=0.50 \
    --policy.subtask_dropout_prob=0.20

@@ -95,7 +95,7 @@ dependencies = [

# ── Feature-scoped extras ──────────────────────────────────
dataset = [
    "datasets>=4.0.0,<5.0.0",
    "datasets>=4.7.0,<5.0.0",
    "pandas>=2.0.0,<3.0.0",  # NOTE: Transitive dependency of datasets
    "pyarrow>=21.0.0,<30.0.0",  # NOTE: Transitive dependency of datasets
    "lerobot[av-dep]",
@@ -200,6 +200,23 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]

# Annotation pipeline (lerobot-annotate). datatrove is mandatory; vllm is
# the preferred backend on Linux, with a transformers fallback elsewhere.
annotations = [
    "lerobot[dataset]",
    "lerobot[transformers-dep]",
    "datatrove>=0.4.0,<2.0.0",
    "vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
]

# Tool implementations under src/lerobot/tools/. Each tool's dependencies
# are isolated so adding a new tool doesn't bloat the base install.
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
tools = [
    "pocket-tts>=1.0.0,<3.0.0",
    "scipy>=1.11.0,<2.0.0",  # SayTool.output_dir uses scipy.io.wavfile
]

# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
@@ -289,10 +306,12 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-smolvla2-runtime="lerobot.scripts.lerobot_smolvla2_runtime:main"

# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"]

[tool.setuptools.packages.find]
where = ["src"]

@@ -0,0 +1,15 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,36 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Steerable annotation pipeline producing ``language_persistent`` and
``language_events`` columns for LeRobot datasets.

The pipeline is decomposed into three independently runnable modules whose
outputs are staged per-episode before a final parquet rewrite:

- :mod:`.modules.plan_subtasks_memory` (Module 1) — persistent styles
- :mod:`.modules.interjections_and_speech` (Module 2) — event styles + speech
- :mod:`.modules.general_vqa` (Module 3) — event-style VQA pairs
"""

from .config import AnnotationPipelineConfig
from .validator import StagingValidator, ValidationReport
from .writer import LanguageColumnsWriter

__all__ = [
    "AnnotationPipelineConfig",
    "LanguageColumnsWriter",
    "StagingValidator",
    "ValidationReport",
]

@@ -0,0 +1,260 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any


@dataclass
class Module1Config:
    """Module 1 hyperparameters: plan + subtasks + memory + task augmentation.

    Subtask decomposition sees the **whole episode** as one Qwen-VL video
    block — no keyframe stride or count: the model handles temporal pooling
    itself and decides where to cut. ``max_video_frames`` only caps the
    number of frames packed into the video block (a model-capacity bound,
    not an annotation-logic knob).
    """

    enabled: bool = True
    n_task_rephrasings: int = 10
    """Number of task rephrasings to generate at ``t=0`` as ``task_aug``
    persistent rows (PR 1 ``CORE_STYLES``). The renderer's ``${task}``
    binding rotates among them deterministically per ``sample_idx``,
    realizing Xiao 2022 / CAST-style task-prompt diversity without
    touching ``meta/tasks.parquet``. Set to 0 to disable."""
    derive_task_from_video: str = "if_short"
    """When to bypass the user-provided ``record.episode_task`` and
    derive a fresh task description from the episode video alone:

    - ``off``       never; always use the canonical task as the basis.
    - ``if_short``  derive when the canonical task is empty, has fewer
                    than ``derive_task_min_words`` words, or matches a
                    placeholder string (``debug``, ``unnamed``, ``tbd``,
                    ...). Default — fixes noisy / placeholder tasks
                    without forcing derivation everywhere.
    - ``always``    ignore the canonical task entirely; always derive
                    from the video. Useful when the dataset's task
                    labels are uniformly bad.

    The video-derived task replaces the canonical task as the basis for
    subtask decomposition, plan, memory, AND the ``task_aug`` rephrasings,
    so every downstream annotation is grounded in what's actually visible.
    ``meta/tasks.parquet`` is NOT modified — the Module-1-derived task
    only lives in ``language_persistent`` rows."""
    derive_task_min_words: int = 3
    """Word-count threshold for ``derive_task_from_video=if_short``."""
    frames_per_second: float = 1.0
    """Sample one image-frame per ``1/fps`` seconds across the episode for
    Module 1's subtask-decomposition prompt. ``1.0`` = 1 fps. Capped by
    ``max_video_frames`` to avoid blowing up the request payload."""
    max_video_frames: int = 128
    """Hard cap on the number of frames Module 1 sends. With ``fps=1`` and
    a 30 s episode this yields 30 frames. Bumped from 32 since each frame
    is small (~30-100 KB PNG when base64'd)."""
    min_subtask_seconds: float = 1.5
    plan_max_steps: int = 8
    use_video_url: bool = False
    """When True (and backend supports it, e.g. ``openai``), Module 1
    sends a ``video_url`` content block pointing at the episode's mp4
    file instead of pre-decoded frames. Lets the server sample frames at
    its own ``fps`` — no in-process conv3d cost. The video file is
    extracted as a per-episode subclip to ``staging/.video_clips/`` so
    the model sees only this episode's frames."""
    use_video_url_fps: float = 1.0
    """Frame-rate hint to send to the server (mm_processor_kwargs.fps).
    Only used when ``use_video_url=True``. ``1.0`` = sample 1 frame per
    second, which is plenty for subtask-boundary detection on most
    manipulation episodes."""


@dataclass
class Module2Config:
    """Module 2 hyperparameters: interjections + paired speech."""

    enabled: bool = True
    max_interjections_per_episode: int = 3
    """Number of mid-episode interjections to generate per episode. Each
    creates a paired ``(interjection, speech)`` event row plus triggers a
    ``plan`` refresh at the same timestamp via Module 1. Bumped from the
    original ``1`` after qwen36moe-10 showed plan/interjection coverage
    was too sparse for Hi Robot-style training."""
    interjection_min_t: float = 2.0
    interjection_window_seconds: float = 2.0
    """How many seconds of video to attach to the interjection prompt as
    visual context. Without this the VLM only sees a single frozen frame
    and writes generic interjections that aren't grounded in the actual
    motion happening at the chosen timestamp."""
    interjection_window_frames: int = 4
"""How many frames to sample over ``interjection_window_seconds``.
|
||||
Default 4 ⇒ ~0.5 fps over the leading 2 seconds — enough for the
|
||||
model to read the ongoing motion, cheap enough to keep prompt size
|
||||
bounded for the 32k context."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Module3Config:
|
||||
"""Module 3 hyperparameters: general VQA."""
|
||||
|
||||
enabled: bool = True
|
||||
vqa_emission_hz: float = 1.0
|
||||
K: int = 3
|
||||
question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VlmConfig:
|
||||
"""Shared Qwen-VL client configuration."""
|
||||
|
||||
backend: str = "openai"
|
||||
"""One of ``vllm``, ``transformers``, ``openai``, or ``stub`` (tests only).
|
||||
|
||||
Default ``openai`` talks to a local OpenAI-compatible server (vllm /
|
||||
transformers) which the CLI auto-spawns when ``auto_serve=True``."""
|
||||
model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
api_base: str = "http://localhost:8000/v1"
|
||||
"""Base URL for the ``openai`` backend."""
|
||||
api_key: str = "EMPTY"
|
||||
"""API key for the ``openai`` backend; ``EMPTY`` works for local servers."""
|
||||
auto_serve: bool = True
|
||||
"""When True with ``backend=openai``, the CLI probes ``api_base``
|
||||
first; if no server answers, it spawns one (default:
|
||||
``transformers serve``), waits for it to be ready, runs the
|
||||
pipeline, and tears it down on exit. Default ``True`` so a single
|
||||
``lerobot-annotate`` call can drive the whole flow. Set to ``False``
|
||||
if you want to fail fast when no server is reachable (e.g. you're
|
||||
pointing at a remote endpoint that should already be up)."""
|
||||
serve_port: int = 8000
|
||||
"""Port the auto-spawned server binds to. Sets ``api_base`` automatically."""
|
||||
serve_command: str | None = None
|
||||
"""Override the auto-serve command (full shell command). When ``None``,
|
||||
we run ``transformers serve <model_id> --port <serve_port> --continuous-batching``.
|
||||
|
||||
When ``parallel_servers > 1``, the literal ``{port}`` placeholder in
|
||||
this command (if present) is substituted per-replica."""
|
||||
parallel_servers: int = 1
|
||||
"""When >1, spawn this many independent inference servers (each pinned
|
||||
to a GPU via ``CUDA_VISIBLE_DEVICES`` and listening on
|
||||
``serve_port + i``) and round-robin client requests across them.
|
||||
Useful when DP/TP NCCL setup is broken on the node — single-GPU
|
||||
replicas don't need cross-GPU communication. When
|
||||
``parallel_servers > num_gpus``, replicas are round-robin-assigned
|
||||
to GPUs (e.g. 4 replicas on 2 GPUs → 0,1,0,1)."""
|
||||
num_gpus: int = 0
|
||||
"""How many physical GPUs are available for round-robin replica
|
||||
placement. ``0`` means ``parallel_servers`` (one GPU per replica,
|
||||
backward-compatible default). Set this to ``2`` with
|
||||
``parallel_servers=4`` to pack 2 replicas per GPU."""
|
||||
client_concurrency: int = 16
|
||||
"""Maximum number of in-flight chat requests the client issues in
|
||||
parallel. vllm batches them internally for free, so bumping this
|
||||
typically gives big throughput wins on a single TP=1 server. Set to
|
||||
``1`` for strict serial calls."""
|
||||
serve_ready_timeout_s: float = 600.0
|
||||
"""Max seconds to wait for the server to start serving requests."""
|
||||
max_new_tokens: int = 512
|
||||
temperature: float = 0.2
|
||||
json_mode: bool = True
|
||||
batch_size: int = 4
|
||||
tensor_parallel_size: int = 1
|
||||
gpu_memory_utilization: float = 0.9
|
||||
"""Fraction of GPU memory vllm allocates for weights + KV cache.
|
||||
Lower (e.g. 0.7) when the vision encoder needs cuDNN workspace, or to
|
||||
avoid CUDNN_STATUS_NOT_INITIALIZED on tight VRAM (30B BF16 on 80 GB)."""
|
||||
max_model_len: int | None = None
|
||||
"""Cap context length. ``None`` keeps the model's default; on H100 80 GB
|
||||
a 30B BF16 model often needs ``max_model_len=8192`` or smaller to leave
|
||||
room for KV cache."""
|
||||
trust_remote_code: bool = False
|
||||
"""Pass ``trust_remote_code`` to HF auto-classes. Default ``False`` —
|
||||
only enable for models that actually ship custom code in their repo
|
||||
(rare for first-class VL releases). On Qwen3-VL it triggers an
|
||||
std::bad_alloc post-load even though the official transformers class
|
||||
is sufficient, so leaving this off is safest."""
|
||||
camera_key: str | None = None
|
||||
"""Override the camera stream used for keyframe attachment. ``None`` picks
|
||||
the first ``observation.images.*`` key the dataset declares."""
|
||||
chat_template_kwargs: dict[str, Any] | None = None
|
||||
"""Forwarded as ``extra_body.chat_template_kwargs`` on every chat call.
|
||||
Use this to pass model-specific template flags such as
|
||||
``{"enable_thinking": false}`` for Qwen3.5/Qwen3.6 to suppress the
|
||||
reasoning preamble that otherwise eats the entire ``max_new_tokens``
|
||||
budget before any JSON is emitted."""
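# Illustrative sketch (not part of the diff): how the replica -> port / GPU
# mapping described by ``parallel_servers``, ``num_gpus`` and ``serve_port``
# plays out. Variable names below are hypothetical.
#
#   serve_port, parallel_servers, num_gpus = 8000, 4, 2
#   for i in range(parallel_servers):
#       port = serve_port + i                       # 8000, 8001, 8002, 8003
#       gpu = i % (num_gpus or parallel_servers)    # 0, 1, 0, 1
#       # replica i runs with CUDA_VISIBLE_DEVICES=<gpu>, listening on <port>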
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutorConfig:
|
||||
"""Executor selection and SLURM hyperparameters."""
|
||||
|
||||
auto_threshold: int = 32
|
||||
force_local: bool = False
|
||||
slurm_partition: str | None = None
|
||||
slurm_gpus: int = 1
|
||||
slurm_time: str = "06:00:00"
|
||||
workers: int = 1
|
||||
episode_parallelism: int = 16
|
||||
"""Number of episodes processed concurrently within each module phase.
|
||||
Each in-flight episode sends 3–5 dependent VLM calls; bumping this is
|
||||
how you actually saturate ``parallel_servers`` and ``client_concurrency``
|
||||
— without it, the executor loops one episode at a time and the
|
||||
inference servers sit ~90% idle. Set to ``1`` for strict serial
|
||||
execution."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnnotationPipelineConfig:
|
||||
"""Top-level config for ``lerobot-annotate``.
|
||||
|
||||
Mirrors the structure of :class:`lerobot.configs.train.TrainPipelineConfig`:
|
||||
a draccus-parsed dataclass that contains nested per-module sub-configs and
|
||||
leaves the dataset, executor, and VLM choices independently knobbable.
|
||||
|
||||
Output is always in-place: the writer rewrites ``data/chunk-*/file-*.parquet``
|
||||
in place. Multiple revisions of the same dataset live in separate copies.
|
||||
"""
|
||||
|
||||
repo_id: str | None = None
|
||||
root: Path | None = None
|
||||
|
||||
staging_dir: Path | None = None
|
||||
"""If unset, defaults to ``<root>/.annotate_staging/``."""
|
||||
|
||||
seed: int = 1729
|
||||
|
||||
module_1: Module1Config = field(default_factory=Module1Config)
|
||||
module_2: Module2Config = field(default_factory=Module2Config)
|
||||
module_3: Module3Config = field(default_factory=Module3Config)
|
||||
|
||||
vlm: VlmConfig = field(default_factory=VlmConfig)
|
||||
executor: ExecutorConfig = field(default_factory=ExecutorConfig)
|
||||
|
||||
skip_validation: bool = False
|
||||
only_episodes: tuple[int, ...] | None = None
|
||||
|
||||
push_to_hub: str | None = None
|
||||
"""If set, after the pipeline completes, upload the annotated dataset
|
||||
root to the Hugging Face Hub as a dataset repo with this id (e.g.
|
||||
``pepijn/super_poulain_steerable``). Creates the repo if missing."""
|
||||
push_private: bool = False
|
||||
"""When ``push_to_hub`` is set, create the repo as private."""
|
||||
push_commit_message: str | None = None
|
||||
"""Override the commit message used for the hub upload."""
|
||||
|
||||
def resolved_staging_dir(self, root: Path) -> Path:
|
||||
return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging"
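# Hedged usage sketch (assuming draccus maps nested fields to dotted CLI flags,
# as the ``--vlm.camera_key=...`` example elsewhere in this diff suggests; the
# repo ids below are made up):
#
#   lerobot-annotate \
#       --repo_id=pepijn/super_poulain \
#       --vlm.parallel_servers=4 --vlm.num_gpus=2 \
#       --executor.episode_parallelism=16 \
#       --push_to_hub=pepijn/super_poulain_steerable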
|
||||
@@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Executor selection: local vs SLURM via datatrove.
|
||||
|
||||
The executor plans **six phases** (four module passes, then validation and write) with the dependency order from the plan:
|
||||
|
||||
phase 1: Module 1 (plan + subtasks + memory)
|
||||
phase 2: Module 2 (interjections + speech)
|
||||
phase 3: Module 1 plan-update pass — re-runs plan emission at every
|
||||
interjection timestamp produced by phase 2
|
||||
phase 4: Module 3 (VQA)
|
||||
phase 5: validator
|
||||
phase 6: writer
|
||||
|
||||
Phase 3 is why ``executor.py`` documents the dependency: Module 1 must be
|
||||
re-entered after Module 2 to refresh ``plan`` rows at interjection times.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import AnnotationPipelineConfig, ExecutorConfig
|
||||
from .reader import EpisodeRecord, iter_episodes
|
||||
from .staging import EpisodeStaging
|
||||
from .validator import StagingValidator
|
||||
from .writer import LanguageColumnsWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhaseResult:
|
||||
"""Summary of one pipeline phase across all episodes."""
|
||||
|
||||
name: str
|
||||
episodes_processed: int
|
||||
episodes_skipped: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRunSummary:
|
||||
"""Aggregated result returned by :meth:`Executor.run`."""
|
||||
|
||||
phases: list[PhaseResult]
|
||||
written_paths: list[Path]
|
||||
validation_report: Any # ValidationReport, kept Any to avoid import cycle
|
||||
|
||||
|
||||
def select_executor_class(num_episodes: int, config: ExecutorConfig) -> str:
|
||||
"""Return ``"local"`` or ``"slurm"`` based on the threshold.
|
||||
|
||||
The plan's "executor selection threshold" lives in
|
||||
:class:`ExecutorConfig.auto_threshold`. ``force_local`` always wins.
|
||||
"""
|
||||
if config.force_local:
|
||||
return "local"
|
||||
return "local" if num_episodes <= config.auto_threshold else "slurm"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Executor:
|
||||
"""Run all four phases over a dataset root.
|
||||
|
||||
The executor is intentionally framework-agnostic: by default it runs the
|
||||
phases inline (suitable for tests, small datasets, and the CLI's
|
||||
``--force-local`` mode). It will optionally hand off to datatrove's
|
||||
:class:`LocalPipelineExecutor` or :class:`SlurmPipelineExecutor` when those
|
||||
are installed and the dataset is large enough to benefit from them.
|
||||
|
||||
Tests construct the executor directly with stub modules.
|
||||
"""
|
||||
|
||||
config: AnnotationPipelineConfig
|
||||
module_1: Any # PlanSubtasksMemoryModule
|
||||
module_2: Any # InterjectionsAndSpeechModule
|
||||
module_3: Any # GeneralVqaModule
|
||||
writer: LanguageColumnsWriter
|
||||
validator: StagingValidator
|
||||
|
||||
def run(self, root: Path) -> PipelineRunSummary:
|
||||
records = list(iter_episodes(root, only_episodes=self.config.only_episodes))
|
||||
n = len(records)
|
||||
if n == 0:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
executor_kind = select_executor_class(n, self.config.executor)
|
||||
print(f"[annotate] {n} episodes total; executor={executor_kind}", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
phases: list[PhaseResult] = []
|
||||
|
||||
# Phase 1: Module 1 (plan + subtasks + memory)
|
||||
phases.append(self._run_module_phase("module_1", records, staging_dir, self.module_1))
|
||||
# Phase 2: Module 2 (interjections + speech). Module 2 reads
|
||||
# Module 1's subtask rows from the same staging tree to ground
|
||||
# the interjection prompt in the correct local subtask.
|
||||
phases.append(self._run_module_phase("module_2", records, staging_dir, self.module_2))
|
||||
# Phase 3: Module 1 plan-update pass at interjection timestamps.
|
||||
phases.append(self._run_plan_update_phase(records, staging_dir))
|
||||
# Phase 4: Module 3 (VQA)
|
||||
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Persist the tool catalog to meta/info.json so chat-template
|
||||
# consumers (PR 3 SmolVLA2 / Pi0.5 / dataset visualizer) can read
|
||||
# it via ``LeRobotDatasetMetadata.tools`` (PR 1). Idempotent and
|
||||
# additive: anything the user pre-populated is preserved; we only
|
||||
# ensure the canonical ``say`` schema is present.
|
||||
self._ensure_tools_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
def _ensure_tools_in_info(self, root: Path) -> None:
|
||||
"""Write ``meta/info.json["tools"]`` if missing the canonical ``say``.
|
||||
|
||||
Reads any user-declared tools already in ``info.json`` and merges
|
||||
the canonical ``SAY_TOOL_SCHEMA`` into the list (deduped by
|
||||
``function.name``). Writes back to disk only if the list
|
||||
changed.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
if not info_path.exists():
|
||||
return
|
||||
try:
|
||||
info = json.loads(info_path.read_text())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[annotate] could not read {info_path}: {exc}", flush=True)
|
||||
return
|
||||
|
||||
existing = info.get("tools")
|
||||
if not isinstance(existing, list):
|
||||
existing = []
|
||||
names = {
|
||||
(t.get("function") or {}).get("name")
|
||||
for t in existing
|
||||
if isinstance(t, dict)
|
||||
}
|
||||
merged = list(existing)
|
||||
if SAY_TOOL_SCHEMA["function"]["name"] not in names:
|
||||
merged.append(SAY_TOOL_SCHEMA)
|
||||
if merged != existing:
|
||||
info["tools"] = merged
|
||||
info_path.write_text(json.dumps(info, indent=2))
|
||||
print(
|
||||
f"[annotate] meta/info.json: tools={[t['function']['name'] for t in merged]}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _run_module_phase(
|
||||
self,
|
||||
name: str,
|
||||
records: list[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
import time as _time # noqa: PLC0415
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed # noqa: PLC0415
|
||||
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
parallelism = max(1, min(self.config.executor.episode_parallelism, n))
|
||||
print(
|
||||
f"[annotate] phase={name} starting on {n} episode(s) "
|
||||
f"(parallelism={parallelism})",
|
||||
flush=True,
|
||||
)
|
||||
t0 = _time.time()
|
||||
|
||||
def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]:
|
||||
i, record = idx_record
|
||||
ep_start = _time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
return i, record.episode_index, _time.time() - ep_start
|
||||
|
||||
processed = 0
|
||||
if parallelism == 1:
|
||||
for i, record in enumerate(records, 1):
|
||||
_, ep_idx, elapsed = _do((i, record))
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} "
|
||||
f"(idx={ep_idx}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=parallelism) as pool:
|
||||
futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)]
|
||||
for fut in as_completed(futures):
|
||||
i, ep_idx, elapsed = fut.result()
|
||||
processed += 1
|
||||
print(
|
||||
f"[annotate] {name} episode {processed}/{n} "
|
||||
f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = _time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
|
||||
|
||||
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
|
||||
therefore calls back into Module 1 with the interjection timestamps so
|
||||
Module 1's existing prompt path is reused.
|
||||
"""
|
||||
if not self.module_1.enabled or not self.module_2.enabled:
|
||||
return PhaseResult(
|
||||
name="module_1_plan_update", episodes_processed=0, episodes_skipped=len(records)
|
||||
)
|
||||
processed = 0
|
||||
for record in records:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
interjection_rows = [
|
||||
row
|
||||
for row in staging.read("module_2")
|
||||
if row.get("style") == "interjection"
|
||||
]
|
||||
interjection_times = [float(row["timestamp"]) for row in interjection_rows]
|
||||
interjection_texts = [str(row.get("content") or "") for row in interjection_rows]
|
||||
if interjection_times:
|
||||
self.module_1.run_plan_updates(
|
||||
record, staging, interjection_times, interjection_texts
|
||||
)
|
||||
processed += 1
|
||||
return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0)
|
||||
@@ -0,0 +1,400 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(Module 1, Module 2) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` PIL images covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. The returned list is
|
||||
intended to be passed as one ``{"type":"video", "video":<list>}``
|
||||
block to a Qwen-VL-compatible model that pools temporally itself.
|
||||
Empty list if no camera available.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
By default the *first* camera key is used for Module 1 (subtask
|
||||
decomposition) and Module 2 (interjection scenarios) — those prompts care
|
||||
about *what is happening*, not which angle. Module 3 (VQA) instead
|
||||
iterates over every camera in :attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
# ``camera_keys`` covers both image- and video-stored cameras
|
||||
# (``video_keys`` is video-only). Some datasets declare cameras with
|
||||
# ``dtype=image``, which would otherwise look empty here and silently
|
||||
# disable Module 3 even though the videos are there.
|
||||
keys = list(getattr(self._meta, "camera_keys", None) or self._meta.video_keys or [])
|
||||
# Last-resort fallback: if metadata didn't surface anything but the
|
||||
# caller explicitly named a camera (``--vlm.camera_key=...``), trust
|
||||
# them — the key is by definition known to exist on the dataset.
|
||||
if not keys and self.camera_key:
|
||||
keys = [self.camera_key]
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# decoder may return fewer frames than requested when some
|
||||
# timestamps fall outside the video; pair what we have and
|
||||
# leave the rest as None to be filtered below.
|
||||
for i, img in zip(miss_indices, decoded):
|
||||
out[i] = img
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = img
|
||||
# filter out any None left over from decode failures
|
||||
return [img for img in out if img is not None]
|
||||
|
||||
def _decode(
|
||||
self, episode_index: int, timestamps: list[float], camera_key: str
|
||||
) -> list[Any]:
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
try:
|
||||
return _decode_pyav_direct(video_path, shifted, self.tolerance_s)
|
||||
except Exception as exc:
|
||||
# Log loudly the first time decoding fails so silent
|
||||
# Module-3-no-op (every prompt skipped because frames_at returned
|
||||
# []) is debuggable from the job log instead of post-hoc parquet
|
||||
# inspection. Subsequent failures stay quiet.
|
||||
if not getattr(self, "_warned_decode_fail", False):
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"VideoFrameProvider._decode failed for episode=%s camera=%s "
|
||||
"video_path=%s: %s",
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
self._warned_decode_fail = True
|
||||
return []
|
||||
|
||||
|
||||
def _decode_pyav_direct(
|
||||
video_path: Any, timestamps: list[float], tolerance_s: float
|
||||
) -> list[Any]:
|
||||
"""Decode the requested timestamps from ``video_path`` using PyAV directly.
|
||||
|
||||
Bypasses ``lerobot.datasets.video_utils.decode_video_frames`` entirely
|
||||
because its "pyav" path actually goes through
|
||||
``decode_video_frames_torchvision`` → ``torchvision.io.VideoReader``,
|
||||
which was removed in torchvision >= 0.22 (the vllm/vllm-openai:latest
|
||||
container ships with torchvision 0.25). The annotation pipeline only
|
||||
needs a handful of PIL images per (episode, ts), so we can decode them
|
||||
with PyAV without any torch dependency at all.
|
||||
|
||||
Returns one ``PIL.Image`` per requested timestamp, in the same order.
|
||||
Any timestamp the decoder couldn't reach is silently dropped (mirrors
|
||||
the previous behaviour); callers filter ``None``/missing entries.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
from PIL import Image # noqa: PLC0415
|
||||
|
||||
if not timestamps:
|
||||
return []
|
||||
|
||||
targets = sorted(set(timestamps))
|
||||
seek_to = max(0.0, min(targets) - max(0.5, tolerance_s))
|
||||
|
||||
container = av.open(str(video_path))
|
||||
try:
|
||||
stream = container.streams.video[0]
|
||||
# PyAV needs the seek target in stream timebase ticks.
|
||||
if stream.time_base is None:
|
||||
seek_pts = 0
|
||||
else:
|
||||
seek_pts = int(seek_to / float(stream.time_base))
|
||||
try:
|
||||
container.seek(seek_pts, any_frame=False, backward=True, stream=stream)
|
||||
except av.AVError:
|
||||
# Some streams reject the explicit seek; fall back to decoding from start.
|
||||
container.seek(0)
|
||||
|
||||
results: dict[float, Any] = {}
|
||||
target_iter = iter(targets)
|
||||
next_target = next(target_iter, None)
|
||||
for frame in container.decode(stream):
|
||||
if next_target is None:
|
||||
break
|
||||
ts = float(frame.pts * frame.time_base) if frame.pts is not None else None
|
||||
if ts is None:
|
||||
continue
|
||||
# Walk past targets we've already overshot — we keep the closest
|
||||
# frame within tolerance.
|
||||
while next_target is not None and ts >= next_target - tolerance_s:
|
||||
if abs(ts - next_target) <= tolerance_s or ts >= next_target:
|
||||
img = frame.to_image() # PIL.Image.Image (RGB)
|
||||
results.setdefault(next_target, img)
|
||||
next_target = next(target_iter, None)
|
||||
else:
|
||||
break
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
return [results[ts] for ts in timestamps if ts in results]
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` images uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally.
|
||||
"""
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
timestamps = list(record.frame_timestamps)
|
||||
else:
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
if t_last <= t0:
|
||||
timestamps = [float(t0)] * n_frames
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps, camera_key=target)
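# Worked example (illustrative): an episode whose frame timestamps span
# 0.0 s .. 10.0 s, queried with max_frames=5, yields uniform sample timestamps
# [0.0, 2.5, 5.0, 7.5, 10.0], which are then decoded via ``frames_at``.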
|
||||
|
||||
|
||||
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert PIL images to Qwen-VL-compatible content blocks."""
|
||||
return [{"type": "image", "image": img} for img in images]
|
||||
|
||||
|
||||
def to_video_block(images: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Wrap a list of PIL images as one Qwen-VL video block.
|
||||
|
||||
Returns ``[]`` when the list is empty, so the caller can splat the result
|
||||
into a content array without a separate emptiness check.
|
||||
"""
|
||||
if not images:
|
||||
return []
|
||||
return [{"type": "video", "video": list(images)}]
|
||||
|
||||
|
||||
def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]:
|
||||
"""Wrap a video file URL as one ``video_url`` block.
|
||||
|
||||
Used by the ``openai`` backend (transformers serve / vllm serve /
|
||||
ktransformers serve), where the server handles frame sampling.
|
||||
Returns ``[]`` when ``url`` is ``None`` so the caller can splat.
|
||||
"""
|
||||
if not url:
|
||||
return []
|
||||
return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}]
|
||||
|
||||
|
||||
def episode_clip_path(
|
||||
record: EpisodeRecord,
|
||||
provider: "VideoFrameProvider",
|
||||
cache_dir: Path,
|
||||
) -> Path | None:
|
||||
"""Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``.
|
||||
|
||||
Returns ``None`` if the dataset has no video tracks. Skips re-extract
|
||||
when the cached clip already exists. Re-encodes to H.264
|
||||
(libx264) so the resulting mp4 is decodable by every downstream
|
||||
video processor — stream-copy would inherit the source codec
|
||||
(often AV1 in modern LeRobot datasets), which vllm's libav build
|
||||
cannot decode.
|
||||
"""
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
if provider.camera_key is None:
|
||||
return None
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4"
|
||||
if out_path.exists() and out_path.stat().st_size > 0:
|
||||
return out_path
|
||||
ep = provider._meta.episodes[record.episode_index]
|
||||
from_timestamp = float(ep[f"videos/{provider.camera_key}/from_timestamp"])
|
||||
to_timestamp = float(ep[f"videos/{provider.camera_key}/to_timestamp"])
|
||||
src = provider.root / provider._meta.get_video_file_path(
|
||||
record.episode_index, provider.camera_key
|
||||
)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-ss",
|
||||
f"{from_timestamp:.3f}",
|
||||
"-to",
|
||||
f"{to_timestamp:.3f}",
|
||||
"-i",
|
||||
str(src),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True, timeout=300)
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
return out_path if out_path.exists() and out_path.stat().st_size > 0 else None
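# For reference, the command above expands to something like (timestamps and
# paths illustrative):
#   ffmpeg -y -loglevel error -ss 12.000 -to 45.500 -i <src.mp4> \
#       -c:v libx264 -preset ultrafast -crf 23 -pix_fmt yuv420p -an ep_000012.mp4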
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .general_vqa import GeneralVqaModule
|
||||
from .interjections_and_speech import InterjectionsAndSpeechModule
|
||||
from .plan_subtasks_memory import PlanSubtasksMemoryModule
|
||||
|
||||
__all__ = [
|
||||
"GeneralVqaModule",
|
||||
"InterjectionsAndSpeechModule",
|
||||
"PlanSubtasksMemoryModule",
|
||||
]
|
||||
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 3: general VQA at a timed cadence.
|
||||
|
||||
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
|
||||
emission. For datasets with multiple cameras, every emission tick produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's Module 3 table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
whose schema depends on the question type. Malformed JSON triggers one
|
||||
retry inside :meth:`VlmClient.generate_json`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module3Config
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..validator import classify_vqa_answer
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]:
|
||||
"""Return the relative frame indices to anchor VQA emissions to.
|
||||
|
||||
For each emission tick (every ``1/hz`` seconds), we anchor ``k``
|
||||
consecutive frames starting at the tick. Ticks fall on the nearest
|
||||
available source frame timestamp.
|
||||
"""
|
||||
if hz <= 0 or k <= 0 or not frame_timestamps:
|
||||
return []
|
||||
t0 = frame_timestamps[0]
|
||||
t_last = frame_timestamps[-1]
|
||||
period = 1.0 / hz
|
||||
indices: list[int] = []
|
||||
t = t0
|
||||
while t <= t_last + 1e-9:
|
||||
# find the index of the nearest frame to t
|
||||
nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t))
|
||||
for offset in range(k):
|
||||
j = nearest_i + offset
|
||||
if j >= len(frame_timestamps):
|
||||
break
|
||||
if not indices or indices[-1] != j:
|
||||
indices.append(j)
|
||||
t += period
|
||||
# dedupe while preserving order
|
||||
seen: set[int] = set()
|
||||
deduped: list[int] = []
|
||||
for i in indices:
|
||||
if i in seen:
|
||||
continue
|
||||
seen.add(i)
|
||||
deduped.append(i)
|
||||
return deduped
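# Worked example (illustrative): 30 fps frames covering 3 s, with
# vqa_emission_hz=0.5 and K=2 -> ticks at t=0.0 and t=2.0, anchoring frame
# indices [0, 1, 60, 61] (two consecutive frames per tick).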
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralVqaModule:
|
||||
"""Emit grounded VQA pairs at a timed cadence."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module3Config
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
if not record.frame_timestamps:
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:vqa")
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — emit nothing rather than producing
|
||||
# untagged rows that would fail validation. Surface a loud one-
|
||||
# time warning so this is never silently a no-op.
|
||||
if not getattr(self, "_warned_no_camera", False):
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"Module 3 (VQA) found no cameras on the frame provider — "
|
||||
"every episode will emit zero VQA rows. Check that the "
|
||||
"dataset declares observation.images.* features in "
|
||||
"meta/info.json; passing --vlm.camera_key=<key> at the "
|
||||
"CLI now also seeds the cameras list as a fallback."
|
||||
)
|
||||
self._warned_no_camera = True
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
question, answer = qa
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("module_3", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras Module 3 should iterate per emission tick.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
"""
|
||||
return list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("module_3_vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(
|
||||
record, [frame_timestamp], camera_key=camera_key
|
||||
)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
question = result.get("question")
|
||||
answer = result.get("answer")
|
||||
if not isinstance(question, str) or not question.strip():
|
||||
return None
|
||||
if not isinstance(answer, dict):
|
||||
return None
|
||||
# The validator will enforce shape; here we just sanity-check that the
|
||||
# answer matches *some* known shape so we can drop garbage early.
|
||||
if classify_vqa_answer(answer) is None:
|
||||
return None
|
||||
return question.strip(), answer
|
||||
|
||||
def _generate_one(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
messages = self._build_messages(record, question_type, frame_timestamp, camera_key)
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
return self._postprocess(result)
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 2: interjections + paired speech (EVENT styles + speech atoms).
|
||||
|
||||
Two sub-passes:
|
||||
|
||||
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
||||
canonical task). No interjection row — the canonical task is already the
|
||||
user utterance from ``meta/tasks.parquet``.
|
||||
|
||||
2. For mid-episode interruptions, emit a co-timestamped pair:
|
||||
{role:user, style:interjection, content:<text>}
|
||||
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
||||
Both rows go in ``language_events`` at the same timestamp.
|
||||
|
||||
Module 1's :meth:`run_plan_updates` reuses Module 2's interjection
|
||||
timestamps to refresh the ``plan`` row at the same instant.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..config import Module2Config
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
|
||||
|
||||
def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
return float(min(frame_timestamps, key=lambda f: abs(f - t)))
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterjectionsAndSpeechModule:
|
||||
"""Generate task-start speech and mid-episode interjection/speech pairs."""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module2Config
|
||||
seed: int = 1729
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
if record.frame_timestamps:
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
initial = self._initial_speech(record)
|
||||
if initial:
|
||||
rows.append(speech_atom(t0, initial))
|
||||
# Pull Module 1's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. Module 1 ran first.
|
||||
subtask_spans = self._read_subtask_spans(staging)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("module_2", rows)
|
||||
|
||||
@staticmethod
|
||||
def _read_subtask_spans(staging: EpisodeStaging) -> list[dict[str, Any]]:
|
||||
rows = [r for r in staging.read("module_1") if r.get("style") == "subtask"]
|
||||
rows.sort(key=lambda r: float(r["timestamp"]))
|
||||
spans: list[dict[str, Any]] = []
|
||||
last_t: float | None = None
|
||||
for r in rows:
|
||||
t = float(r["timestamp"])
|
||||
if last_t is not None and spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
last_t = t
|
||||
return spans
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
for span in spans:
|
||||
if float(span["start"]) <= t:
|
||||
current = span.get("text")
|
||||
else:
|
||||
break
|
||||
return current
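# Illustrative: with spans starting at 0.0 ("pick the cup"), 5.0 ("place it"),
# and 10.0 ("wipe the table"):
#   _subtask_at(spans, 7.0) -> "place it"
#   _subtask_at(spans, 0.0) -> "pick the cup"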
|
||||
|
||||
def _initial_speech(self, record: EpisodeRecord) -> str | None:
|
||||
prompt = load_prompt("module_2_initial_speech").format(
|
||||
episode_task=record.episode_task,
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("text"), str):
|
||||
text = result["text"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _mid_episode_interjections(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate interjections aligned with the actual demo trajectory.
|
||||
|
||||
Teleop data is frozen — the robot already executed every step in
|
||||
the video. A *counterfactual* interjection like "actually skip
|
||||
the wipe" contradicts what then happens in the video, which is
|
||||
what the qwen36moe-10/11 runs surfaced as low-quality interjections.
|
||||
|
||||
Instead, anchor every interjection at a subtask boundary and
|
||||
write it as a natural user request for the *upcoming* subtask.
|
||||
The robot's visible next behavior IS the interjection's effect,
|
||||
so the training signal stays consistent: interjection text →
|
||||
plan refresh → action stream all line up.
|
||||
"""
|
||||
if self.config.max_interjections_per_episode <= 0:
|
||||
return []
|
||||
if len(subtask_spans) < 2:
|
||||
# Need at least one transition (subtask 0 → subtask 1).
|
||||
return []
|
||||
# Deterministic per-episode RNG so reruns are stable across SLURM jobs.
|
||||
rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")
|
||||
|
||||
# Boundaries: the start time of every subtask except the first
|
||||
# (which is just t0 and is covered by the initial-task speech atom).
|
||||
boundaries: list[tuple[float, str, str]] = []
|
||||
for i in range(1, len(subtask_spans)):
|
||||
ts = float(subtask_spans[i]["start"])
|
||||
if ts < self.config.interjection_min_t:
|
||||
continue
|
||||
prev_text = (subtask_spans[i - 1].get("text") or "").strip()
|
||||
next_text = (subtask_spans[i].get("text") or "").strip()
|
||||
if not next_text:
|
||||
continue
|
||||
boundaries.append((ts, prev_text, next_text))
|
||||
if not boundaries:
|
||||
return []
|
||||
|
||||
n = min(self.config.max_interjections_per_episode, len(boundaries))
|
||||
chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0])
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for t, prev_subtask, next_subtask in chosen:
|
||||
t_snap = _snap_to_frame(t, record.frame_timestamps)
|
||||
# Window straddles the boundary so the VLM sees the end of the
|
||||
# previous subtask and the start of the next one — same
|
||||
# conditioning the policy will see at training time.
|
||||
window_ts = self._window_timestamps(t_snap, record.frame_timestamps)
|
||||
prompt = load_prompt("module_2_interjection").format(
|
||||
episode_task=record.episode_task,
|
||||
prev_subtask=prev_subtask or "(starting from initial state)",
|
||||
next_subtask=next_subtask,
|
||||
timestamp=t_snap,
|
||||
window_seconds=self.config.interjection_window_seconds,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, window_ts)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
interjection_text = result.get("interjection")
|
||||
speech_text = result.get("speech")
|
||||
if not isinstance(interjection_text, str) or not interjection_text.strip():
|
||||
continue
|
||||
if not isinstance(speech_text, str) or not speech_text.strip():
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": interjection_text.strip(),
|
||||
"style": "interjection",
|
||||
"timestamp": t_snap,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
out.append(speech_atom(t_snap, speech_text.strip()))
|
||||
return out
|
||||
|
||||
def _window_timestamps(
|
||||
self, t_anchor: float, frame_timestamps: Sequence[float]
|
||||
) -> list[float]:
|
||||
"""Return a small set of frame timestamps centered on ``t_anchor``.
|
||||
|
||||
The window straddles the subtask boundary the interjection sits
|
||||
on: roughly half the frames cover the end of the previous
|
||||
subtask, half cover the start of the next one. The VLM therefore
|
||||
sees BOTH what just finished AND what's about to start, which is
|
||||
the conditioning we need to write a natural "now please do X"
|
||||
request that matches the visible upcoming behavior.
|
||||
"""
|
||||
if not frame_timestamps:
|
||||
return [t_anchor]
|
||||
n = max(1, int(self.config.interjection_window_frames))
|
||||
if n == 1:
|
||||
return [t_anchor]
|
||||
window = float(self.config.interjection_window_seconds)
|
||||
step = window / max(1, n - 1)
|
||||
# Center the window on the anchor so half lands before, half after.
|
||||
start_offset = -window / 2.0
|
||||
targets = [t_anchor + start_offset + step * i for i in range(n)]
|
||||
last_ts = float(frame_timestamps[-1])
|
||||
snapped: list[float] = []
|
||||
seen: set[float] = set()
|
||||
for tgt in targets:
|
||||
clamped = min(last_ts, max(0.0, tgt))
|
||||
t = _snap_to_frame(clamped, frame_timestamps)
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
snapped.append(t)
|
||||
return snapped or [t_anchor]
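# Worked example (illustrative): interjection_window_seconds=2.0 and
# interjection_window_frames=4 give raw targets at t_anchor + [-1.0, -0.33,
# +0.33, +1.0]; each is then clamped to [0, last frame] and snapped to the
# nearest real frame timestamp, dropping duplicates.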
|
||||
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 1: subtask decomposition + plan + memory (PERSISTENT styles)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ..config import Module1Config
|
||||
from ..frames import (
|
||||
FrameProvider,
|
||||
VideoFrameProvider,
|
||||
episode_clip_path,
|
||||
null_provider,
|
||||
to_video_block,
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
|
||||
def _snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp."""
|
||||
if not frame_timestamps:
|
||||
return float(t)
|
||||
nearest = min(frame_timestamps, key=lambda f: abs(f - t))
|
||||
return float(nearest)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanSubtasksMemoryModule:
|
||||
"""Generate subtask spans, plan, and memory rows.
|
||||
|
||||
All output is persistent (lives in ``language_persistent``):
|
||||
|
||||
- ``subtask`` rows: one per span, stamped at the span's *start* timestamp
|
||||
(snapped to an exact frame).
|
||||
- ``plan`` rows: emitted at ``t=0``; refreshed at every interjection
|
||||
timestamp via :meth:`run_plan_updates` (called by the executor after
|
||||
Module 2 completes).
|
||||
- ``memory`` rows: emitted at each subtask boundary (= subtask start
|
||||
timestamp from the second subtask onward).
|
||||
"""
|
||||
|
||||
vlm: VlmClient
|
||||
config: Module1Config
|
||||
frame_provider: FrameProvider = field(default_factory=null_provider)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.config.enabled
|
||||
|
||||
def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
|
||||
rows: list[dict[str, Any]] = []
|
||||
# Resolve the task that drives every other Module-1 prompt. May be
|
||||
# the canonical ``record.episode_task`` (default), or a fresh
|
||||
# description derived from the video when the canonical task is
|
||||
# empty / placeholder / forced-off (see Module1Config.derive_task_*).
|
||||
effective_task = self._resolve_effective_task(record)
|
||||
# ``task_aug`` rows at t=0 (role=user), one per rephrasing — the
|
||||
# PR 1 renderer rotates ``${task}`` deterministically through them
|
||||
# so the policy sees diverse phrasings during training.
|
||||
t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
|
||||
if self.config.n_task_rephrasings > 0 and effective_task:
|
||||
rephrasings = self._generate_task_rephrasings(
|
||||
effective_task, n=self.config.n_task_rephrasings
|
||||
)
|
||||
# Always include the effective task itself as the first variant
|
||||
# so the rotation is guaranteed to cover the source-of-truth
|
||||
# phrasing, not just synthetic alternatives.
|
||||
seen: set[str] = set()
|
||||
ordered = [effective_task, *rephrasings]
|
||||
for phrasing in ordered:
|
||||
key = phrasing.strip()
|
||||
if not key or key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
rows.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": key,
|
||||
"style": "task_aug",
|
||||
"timestamp": t0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
|
||||
subtask_spans = self._generate_subtasks(record, task=effective_task)
|
||||
# subtask rows
|
||||
for span in subtask_spans:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": span["text"],
|
||||
"style": "subtask",
|
||||
"timestamp": _snap_to_frame(span["start"], record.frame_timestamps),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# plan row at t=0
|
||||
plan_text = self._generate_plan(record, subtask_spans, task=effective_task)
|
||||
if plan_text is not None:
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": float(t0),
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
# memory rows at every subtask boundary except the very first start
|
||||
prior_memory = ""
|
||||
for i, span in enumerate(subtask_spans[1:], start=1):
|
||||
completed = subtask_spans[i - 1]["text"]
|
||||
remaining = [s["text"] for s in subtask_spans[i:]]
|
||||
mem_text = self._generate_memory(
|
||||
record, prior_memory, completed, remaining, task=effective_task
|
||||
)
|
||||
if mem_text:
|
||||
ts = _snap_to_frame(span["start"], record.frame_timestamps)
|
||||
rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": mem_text,
|
||||
"style": "memory",
|
||||
"timestamp": ts,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
prior_memory = mem_text
|
||||
staging.write("module_1", rows)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task derivation + rephrasings
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_PLACEHOLDER_TASKS: frozenset[str] = frozenset(
|
||||
{
|
||||
"debug",
|
||||
"test",
|
||||
"tbd",
|
||||
"todo",
|
||||
"n/a",
|
||||
"na",
|
||||
"untitled",
|
||||
"unnamed",
|
||||
"default",
|
||||
"placeholder",
|
||||
}
|
||||
)
|
||||
|
||||
def _resolve_effective_task(self, record: EpisodeRecord) -> str:
|
||||
"""Decide which task string drives Module 1 for this episode.
|
||||
|
||||
Returns the user-supplied ``record.episode_task`` unless
|
||||
``derive_task_from_video`` says otherwise (see config docstring).
|
||||
Falls back gracefully to the canonical task if video derivation
|
||||
fails.
|
||||
"""
|
||||
canonical = (record.episode_task or "").strip()
|
||||
mode = (self.config.derive_task_from_video or "off").strip().lower()
|
||||
if mode == "always":
|
||||
derived = self._derive_task_from_video(record)
|
||||
return derived or canonical
|
||||
if mode == "if_short" and self._task_seems_bad(canonical):
|
||||
derived = self._derive_task_from_video(record)
|
||||
if derived:
|
||||
return derived
|
||||
return canonical
|
||||
|
||||
def _task_seems_bad(self, task: str) -> bool:
|
||||
if not task:
|
||||
return True
|
||||
if len(task.split()) < int(self.config.derive_task_min_words):
|
||||
return True
|
||||
if task.lower() in self._PLACEHOLDER_TASKS:
|
||||
return True
|
||||
return False
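# Illustrative (assuming derive_task_min_words=3, a hypothetical value):
#   _task_seems_bad("")                       -> True   (empty)
#   _task_seems_bad("debug")                  -> True   (placeholder / too short)
#   _task_seems_bad("fold the towel neatly")  -> False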
|
||||
|
||||
def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
|
||||
"""Ask the VLM "what is this video about" with no task hint at all."""
|
||||
prompt = load_prompt("module_1_video_task")
|
||||
video_block = self._episode_video_block(record)
|
||||
content = [*video_block, {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("task"), str):
|
||||
text = result["task"].strip()
|
||||
if text:
|
||||
return text
|
||||
return None
|
||||
|
||||
def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
|
||||
"""Generate ``n`` text-only paraphrases of ``base_task``."""
|
||||
if n <= 0 or not base_task:
|
||||
return []
|
||||
prompt = load_prompt("module_1_task_rephrasings").format(
|
||||
base_task=base_task, n=n
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if not isinstance(result, dict):
|
||||
return []
|
||||
raw = result.get("rephrasings")
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
out: list[str] = []
|
||||
for item in raw:
|
||||
if isinstance(item, str):
|
||||
cleaned = item.strip().strip('"').strip("'")
|
||||
if cleaned:
|
||||
out.append(cleaned)
|
||||
return out[:n]
|
||||
|
||||
def _episode_video_block(self, record: EpisodeRecord) -> list[dict[str, Any]]:
|
||||
"""Same video block ``_generate_subtasks`` builds — extracted helper."""
|
||||
if not record.frame_timestamps:
|
||||
return []
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = episode_clip_path(record, self.frame_provider, cache_dir)
|
||||
return (
|
||||
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
if clip is not None
|
||||
else []
|
||||
)
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
target_count = max(
|
||||
1, int(round(episode_duration * self.config.frames_per_second))
|
||||
)
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
return to_video_block(video_frames)
|
||||
|
||||
def run_plan_updates(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
staging: EpisodeStaging,
|
||||
interjection_times: Sequence[float],
|
||||
interjection_texts: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Append additional ``plan`` rows at every interjection timestamp.
|
||||
|
||||
Plans refresh ONLY on user interjections — subtask generation
|
||||
runs ~1 Hz at inference, but plan re-emission is event-driven.
|
||||
Now also forwards the interjection's own text into the prompt so
|
||||
the refreshed plan can actually reflect the user's correction
|
||||
(the previous version told the model "an interjection happened"
|
||||
without telling it what the user said).
|
||||
"""
|
||||
existing = staging.read("module_1")
|
||||
spans = self._reconstruct_subtasks_from_rows(existing)
|
||||
already_planned: set[float] = {
|
||||
float(r["timestamp"]) for r in existing if r.get("style") == "plan"
|
||||
}
|
||||
new_rows = list(existing)
|
||||
|
||||
texts: list[str | None] = (
|
||||
[None] * len(interjection_times)
|
||||
if interjection_texts is None
|
||||
else [str(t) if t else None for t in interjection_texts]
|
||||
)
|
||||
for raw_t, inter_text in zip(interjection_times, texts):
|
||||
t = _snap_to_frame(raw_t, record.frame_timestamps)
|
||||
if t in already_planned:
|
||||
continue
|
||||
already_planned.add(t)
|
||||
plan_text = self._generate_plan(
|
||||
record, spans, refresh_t=t, interjection=inter_text
|
||||
)
|
||||
if plan_text is not None:
|
||||
new_rows.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": plan_text,
|
||||
"style": "plan",
|
||||
"timestamp": t,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("module_1", new_rows)
|
||||
|
||||
@staticmethod
|
||||
def _reconstruct_subtasks_from_rows(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
out = []
|
||||
last_t: float | None = None
|
||||
for row in sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
):
|
||||
t = float(row["timestamp"])
|
||||
if last_t is not None:
|
||||
out[-1]["end"] = t
|
||||
out.append({"text": row.get("content") or "", "start": t, "end": t})
|
||||
last_t = t
|
||||
return out
|
||||
|
||||
def _generate_subtasks(
|
||||
self, record: EpisodeRecord, *, task: str | None = None
|
||||
) -> list[dict[str, Any]]:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
|
||||
prompt = load_prompt("module_1_subtasks").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
min_subtask_seconds=self.config.min_subtask_seconds,
|
||||
max_steps=self.config.plan_max_steps,
|
||||
episode_duration=f"{episode_duration:.3f}",
|
||||
)
|
||||
if self.config.use_video_url and isinstance(self.frame_provider, VideoFrameProvider):
|
||||
cache_dir = Path(self.frame_provider.root) / ".annotate_staging" / ".video_clips"
|
||||
clip = episode_clip_path(record, self.frame_provider, cache_dir)
|
||||
video_block = (
|
||||
to_video_url_block(f"file://{clip}", fps=self.config.use_video_url_fps)
|
||||
if clip is not None
|
||||
else []
|
||||
)
|
||||
else:
|
||||
target_count = max(
|
||||
1,
|
||||
int(round(episode_duration * self.config.frames_per_second)),
|
||||
)
|
||||
target_count = min(target_count, self.config.max_video_frames)
|
||||
video_frames = self.frame_provider.video_for_episode(record, target_count)
|
||||
video_block = to_video_block(video_frames)
|
||||
content = [*video_block, {"type": "text", "text": prompt}]
|
||||
messages = [{"role": "user", "content": content}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
spans = result.get("subtasks") if isinstance(result, dict) else None
|
||||
if not spans:
|
||||
return []
|
||||
# clamp to [t0, t_last] and sort
|
||||
t0 = record.frame_timestamps[0]
|
||||
t_last = record.frame_timestamps[-1]
|
||||
cleaned: list[dict[str, Any]] = []
|
||||
for span in spans:
|
||||
try:
|
||||
start = float(span["start"])
|
||||
end = float(span["end"])
|
||||
text = str(span["text"]).strip()
|
||||
except (KeyError, ValueError, TypeError):
|
||||
continue
|
||||
start = max(t0, min(start, t_last))
|
||||
end = max(t0, min(end, t_last))
|
||||
if end < start:
|
||||
start, end = end, start
|
||||
if not text:
|
||||
continue
|
||||
cleaned.append({"text": text, "start": start, "end": end})
|
||||
cleaned.sort(key=lambda s: s["start"])
|
||||
return cleaned
|
||||
|
||||
def _generate_plan(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
subtask_spans: Sequence[dict[str, Any]],
|
||||
*,
|
||||
refresh_t: float | None = None,
|
||||
interjection: str | None = None,
|
||||
task: str | None = None,
|
||||
) -> str | None:
|
||||
if not subtask_spans:
|
||||
return None
|
||||
subtasks_text = "\n".join(f"- {s['text']}" for s in subtask_spans)
|
||||
prompt = load_prompt("module_1_plan").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
subtasks_text=subtasks_text,
|
||||
plan_max_steps=self.config.plan_max_steps,
|
||||
)
|
||||
if refresh_t is not None:
|
||||
# ``current_subtask`` is the span the refresh time falls into,
|
||||
# so the model knows where in the demonstration the planner is
|
||||
# standing when it re-emits.
|
||||
current_subtask = ""
|
||||
for span in subtask_spans:
|
||||
if float(span["start"]) <= refresh_t and (
|
||||
"end" not in span or float(span["end"]) > refresh_t
|
||||
):
|
||||
current_subtask = span.get("text", "")
|
||||
break
|
||||
if interjection:
|
||||
prompt += (
|
||||
f"\n\n(Plan refresh at t={refresh_t:.2f}s after a user "
|
||||
f"interjection: {interjection!r}. Current subtask just "
|
||||
f"before the interjection: {current_subtask!r}. Update "
|
||||
f"the plan so it reflects the interjection — drop or "
|
||||
f"reorder steps as needed; do not just restate.)\n"
|
||||
)
|
||||
else:
|
||||
# Refresh without an interjection text: still tell the model
|
||||
# where in the episode the plan stands so the re-emission
|
||||
# is grounded. Should be rare — plan refreshes are
|
||||
# interjection-driven by design.
|
||||
prompt += (
|
||||
f"\n\n(Plan refresh at t={refresh_t:.2f}s. Current "
|
||||
f"subtask: {current_subtask!r}.)\n"
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("plan"), str):
|
||||
return result["plan"].strip()
|
||||
return None
|
||||
|
||||
def _generate_memory(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
prior_memory: str,
|
||||
completed: str,
|
||||
remaining: Sequence[str],
|
||||
*,
|
||||
task: str | None = None,
|
||||
) -> str:
|
||||
prompt = load_prompt("module_1_memory").format(
|
||||
episode_task=(task if task is not None else record.episode_task),
|
||||
prior_memory=prior_memory or "(none)",
|
||||
completed_subtask=completed,
|
||||
remaining_subtasks=", ".join(remaining) if remaining else "(none)",
|
||||
)
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
if isinstance(result, dict) and isinstance(result.get("memory"), str):
|
||||
return result["memory"].strip()
|
||||
return ""
|
||||
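The rows above snap every emitted timestamp onto a source frame via `_snap_to_frame`, which is defined elsewhere in this module and not shown in this excerpt. A minimal sketch of the assumed behavior (nearest-frame clamping, so the validator's exact-membership check later in this diff passes):

def _snap_to_frame(t: float, frame_timestamps: tuple[float, ...]) -> float:
    # Assumed helper, illustrative only: the nearest source frame wins, so
    # every row carries a timestamp exactly equal to one of the episode's
    # frame times.
    return min(frame_timestamps, key=lambda ft: abs(ft - float(t)))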
@@ -0,0 +1,33 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompt templates loaded as plain text.

One file per use site. Templates use ``str.format(**vars)`` substitution; we
intentionally avoid jinja2 here so the templates remain inspectable in
plain editors and roundtrip cleanly through ``ruff format``.
"""

from __future__ import annotations

from pathlib import Path

_DIR = Path(__file__).parent


def load(name: str) -> str:
    """Read prompt template ``name.txt`` from the ``prompts/`` directory."""
    path = _DIR / f"{name}.txt"
    return path.read_text(encoding="utf-8")
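A minimal usage sketch (the import path is hypothetical; adjust to wherever this package lives in the tree). It also shows why the templates double their literal JSON braces: ``str.format`` consumes single braces for placeholders and renders ``{{``/``}}`` as literal ``{``/``}``.

from lerobot.annotations import prompts  # hypothetical import path

template = prompts.load("module_1_plan")
prompt = template.format(
    episode_task="put the cube in the bin",
    subtasks_text="- pick up the cube\n- place the cube in the bin",
    plan_max_steps=6,
)
# After formatting, the doubled braces in the template render as literal
# JSON braces, e.g. '{ "plan": "1. ...\n2. ...\n3. ..." }'.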
@@ -0,0 +1,31 @@
You are updating the robot's compressed semantic memory at the boundary of
a completed subtask.

Reference (verbatim from MEM, Torne 2026):
"Remove or compress information in the language memory whenever
appropriate. Keep ONLY the minimal set of relevant information for future
task execution. Specific object attributes (colors, precise quantities of
each item) get discarded when their details won't affect subsequent
actions. Functional outcomes (where items went, how many) are preserved."

Concrete example from MEM:
Before: "I put a light green bowl, a dark blue bowl and a bright yellow
bowl into the top right cabinet"
After: "I placed three bowls in the top right cabinet"

Episode task: "{episode_task}"
Previous memory: {prior_memory}
Just-completed subtask: "{completed_subtask}"
Remaining subtasks (for relevance judgement only): {remaining_subtasks}

Update the memory as a compact state note.

Rules:
- Keep only facts needed later.
- Keep WHAT changed; drop HOW it was done.
- Use fragments when clear.
- Prefer: "bowl in box; lid still open"
- Avoid: "The robot placed the bowl into the box and the lid remains open."

Output strictly valid JSON:
{{ "memory": "<brief state note>" }}

@@ -0,0 +1,18 @@
You are the high-level planner for a robot demonstrating: "{episode_task}".

Given the subtask decomposition below, write a compact hierarchical PLAN.
Use short imperative fragments, like pi0.7 context prompts.

Subtasks for context:
{subtasks_text}

Authoring rules:
- 3 to {plan_max_steps} steps.
- Each step is one logical chunk, not one motion.
- Steps must be in execution order.
- Brief commands, not full sentences.
- Prefer: "open air fryer"; avoid: "The robot should open the air fryer."
- Plain text, no markdown headers.

Output strictly valid JSON:
{{ "plan": "1. ...\n2. ...\n3. ..." }}

@@ -0,0 +1,34 @@
You are labeling a teleoperated robot demonstration.

The user originally asked: "{episode_task}"

You are shown the entire demonstration as a single video. Watch the
whole clip, then segment it into a list of consecutive atomic subtasks
the robot performs. Write compact action labels, not prose.

Authoring rules — based on Hi Robot (Shi 2025) atom granularity and
pi0.7 (Physical Intelligence 2025) compact context prompts:

- Each subtask is one atomic skill the low-level policy can execute,
  e.g. "pick up one piece of lettuce", "place the bowl into the box",
  "move the right arm to the left".
- Capture HOW when useful, but keep it brief — e.g. prefer
  "grasp the handle of the sponge with the left hand" to "pick up the
  sponge".
- Use verb phrases, not full sentences.
- Subtasks are non-overlapping and cover the full episode in order.
  Choose the cut points yourself based on what you see in the video
  (gripper open/close events, contact, regrasps, transitions).
- Each subtask spans at least {min_subtask_seconds} seconds.
- Do not exceed {max_steps} subtasks total.
- Every subtask's [start_time, end_time] must lie within
  [0.0, {episode_duration}] seconds.

Output strictly valid JSON of shape:

{{
  "subtasks": [
    {{"text": "<how-not-what>", "start": <float>, "end": <float>}},
    ...
  ]
}}
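For reference, an illustrative (made-up) decomposition that satisfies the schema above; start/end are seconds into the episode and must stay inside [0.0, episode_duration]:

example_subtasks = {
    "subtasks": [
        {"text": "grasp the sponge handle with the left hand", "start": 0.0, "end": 4.2},
        {"text": "wipe the table from left to right", "start": 4.2, "end": 11.8},
        {"text": "place the sponge back in the holder", "start": 11.8, "end": 15.0},
    ]
}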
@@ -0,0 +1,32 @@
You are generating training data for a Hi Robot-style policy. We need
{n} alternative phrasings of the same robot task so the policy sees
diverse user prompts during training instead of the same canonical
string repeated every frame.

Original task:
"{base_task}"

Generate exactly {n} alternative phrasings of the same task. Vary:

- formality (casual / polite / curt)
- verbosity (mostly short imperative; occasional polite request)
- word choice (synonyms, different verbs)
- sentence structure (imperative / question / suggestion)

Hard rules:
- Each phrasing MUST preserve the exact meaning of the original task.
  Do not change which object is involved, the destination, or the
  action. Do not add extra steps. Do not invent new objects.
- Each phrasing must be a short phrase or sentence, plain prose, no
  markdown, no quotes, no list numbers.
- Phrasings must be distinct — no near-duplicates.
- Output exactly {n} entries.

Output strictly valid JSON:
{{
  "rephrasings": [
    "<phrasing 1>",
    "<phrasing 2>",
    ...
  ]
}}

@@ -0,0 +1,17 @@
The video above shows a robot manipulation episode in full. Look at
the entire video and describe in ONE concise sentence what the robot
is doing.

Rules:
- One sentence, in natural English, like a user instruction.
- Capture the goal of the demonstration, not low-level motions.
  Example: "place the yellow cube into the red bin" — not "move the
  end-effector down 5cm and close the gripper".
- 4 to 15 words. Plain prose, no markdown, no bullets, no quotes.
- Do not invent objects or actions that aren't visible.
- Do not output anything other than the JSON object below.

Output strictly valid JSON:
{{
  "task": "<single concise sentence describing what the robot does in this video>"
}}

@@ -0,0 +1,12 @@
The user just asked the robot: "{episode_task}".

Generate a short verbal acknowledgement the robot would speak back before
beginning the task. Style: compact, confident, friendly.

Examples (Hi Robot, Shi 2025): "Sure, I won't put cheese on it.",
"OK, starting with the sponge.", "Got it.".

Prefer very short replies: "Got it.", "On it.", "OK."

Output strictly valid JSON:
{{ "text": "<the spoken acknowledgement>" }}

@@ -0,0 +1,46 @@
You are generating training data for a Hi Robot-style hierarchical
robot policy. The robot in this demonstration has ALREADY executed
every step shown in the video — we cannot retroactively change the
action stream. To keep training data consistent with the video, the
"interjection" must align with what the robot is *about to do next* in
the demonstration, framed as a natural mid-task user request.

The episode's overall task: "{episode_task}".

The images above show roughly {window_seconds:.1f} seconds straddling a
subtask boundary in the demonstration:

- Subtask the robot just finished: "{prev_subtask}"
- Subtask the robot is about to start: "{next_subtask}"
- Time into episode: {timestamp:.2f}s

Write ONE compact interjection the user would naturally say at this
moment to prompt / confirm / encourage the robot to do "{next_subtask}".
Keep it like a mid-task coaching cue, not a full instruction paragraph.
Also write the robot's compact verbal acknowledgement.

Hard rules:

- The interjection MUST be consistent with the next subtask. The user
  cannot ask for something different from what the robot then does in
  the video. If you're tempted to say "actually skip X" or "do Y
  instead", DO NOT — those would contradict the demonstration.
- The interjection must reference an object, location, or action that
  is plausible given the visible scene and the next subtask text.
- One short phrase or sentence each. Conversational, not robotic.
- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}."
- Keep robot speech very short: "OK.", "On it.", "Doing that."

Style examples (vary the phrasing — don't reuse these verbatim):
- "Now go ahead and {next_subtask}."
- "Great, can you {next_subtask} next?"
- "{next_subtask}, please."
- "Before you continue, please {next_subtask}."
- "Looking good — {next_subtask} now."
- "Okay, {next_subtask}."

Output strictly valid JSON:
{{
  "interjection": "<short cue from the user, asking for the next subtask>",
  "speech": "<short robot acknowledgement>"
}}

@@ -0,0 +1,32 @@
You are generating a frame-grounded visual question/answer pair for
chain-of-thought training. Reference: ECoT (Zawalski 2024) and Steerable
Policies — both train policies on grounded features such as bounding box
pixel coordinates, keypoints, counts, attributes, and spatial relations.

The frame shows a robot working on: "{episode_task}".

Question types and the EXACT answer JSON shape required for each:

bbox      => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
              "bbox": [x1, y1, x2, y2]}}, ...]}}
              bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
              ECoT example: "a white cup [124, 25, 176, 113]".

keypoint  => {{"label": "<point>", "point_format": "xy",
              "point": [x, y]}}

count     => {{"label": "<obj>", "count": <int>,
              "note": "<optional short note>"}}

attribute => {{"label": "<obj>", "attribute": "<color|shape|state|...>",
              "value": "<observed value>"}}

spatial   => {{"subject": "<obj>", "relation": "<left_of|right_of|on|in|"
              "above|below|near>", "object": "<obj>"}}

Generate a question of type "{question_type}". Output strictly valid JSON:

{{
  "question": "<short, frame-grounded question>",
  "answer": <object whose shape matches the schema above>
}}
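An illustrative (made-up) bbox-type question/answer pair that the pre-write validator later in this diff would accept; the answer payload is classified purely by its required keys:

example_vqa = {
    "question": "Where is the white cup in this frame?",
    "answer": {
        "detections": [
            {"label": "white cup", "bbox_format": "xyxy", "bbox": [124, 25, 176, 113]}
        ]
    },
}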
@@ -0,0 +1,219 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Datatrove-shaped reader.

The reader walks ``data/chunk-*/file-*.parquet`` and yields one record per
episode containing:

- ``episode_index``: int
- ``frame_timestamps``: tuple[float, ...]
- ``frame_indices``: tuple[int, ...]
- ``episode_task``: str (canonical task from ``meta/tasks.parquet``)
- ``data_path``: pathlib.Path of the source parquet shard
- ``frames_df``: pandas.DataFrame slice for the episode (only loaded on demand)

This shape lets each module operate per-episode without loading all parquet
rows into memory at once. It deliberately does not depend on datatrove —
datatrove integration wraps this generator inside a ``PipelineStep`` in
:mod:`.executor`.
"""

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pyarrow.parquet as pq

from lerobot.datasets.utils import DEFAULT_TASKS_PATH


@dataclass
class EpisodeRecord:
    """Per-episode record yielded by the reader."""

    episode_index: int
    episode_task: str
    frame_timestamps: tuple[float, ...]
    frame_indices: tuple[int, ...]
    data_path: Path
    row_offset: int  # row offset within the parquet file where this episode starts
    row_count: int  # number of rows for this episode

    def frames_df(self):  # type: ignore[no-untyped-def]
        """Lazy-load the pandas slice for this episode."""
        import pandas as pd  # noqa: PLC0415 - deferred for optional dataset extra

        table = pq.read_table(self.data_path)
        df: pd.DataFrame = table.to_pandas()
        slice_ = df.iloc[self.row_offset : self.row_offset + self.row_count].reset_index(drop=True)
        return slice_


def _load_tasks_lookup(root: Path) -> dict[int, str]:
    tasks_path = root / DEFAULT_TASKS_PATH
    if not tasks_path.exists():
        return {}
    table = pq.read_table(tasks_path)
    cols = {name: table.column(name).to_pylist() for name in table.column_names}
    if "task_index" in cols and "task" in cols:
        return dict(zip(cols["task_index"], cols["task"], strict=True))
    raise ValueError(f"meta/tasks.parquet at {tasks_path} missing 'task_index' or 'task'")


def iter_episodes(root: Path, *, only_episodes: tuple[int, ...] | None = None) -> Iterator[EpisodeRecord]:
    """Yield :class:`EpisodeRecord` for every episode under ``root/data/``.

    Episodes are yielded in ascending ``episode_index`` order. The reader does
    not assume a specific chunk/file layout: it scans every ``*.parquet``
    under ``data/`` and groups by ``episode_index``.
    """
    tasks = _load_tasks_lookup(root)
    data_dir = root / "data"
    parquet_files = sorted(data_dir.rglob("*.parquet"))

    only_set = set(only_episodes) if only_episodes is not None else None

    for path in parquet_files:
        yield from _iter_one_path(path, tasks, only_set)


def _iter_one_path(path: Path, tasks: dict[int, str], only_set: set[int] | None) -> Iterator[EpisodeRecord]:
    table = pq.read_table(path)
    names = table.column_names
    if "episode_index" not in names:
        return
    episode_col = table.column("episode_index").to_pylist()
    timestamp_col = (
        table.column("timestamp").to_pylist() if "timestamp" in names else [0.0] * len(episode_col)
    )
    frame_col = (
        table.column("frame_index").to_pylist() if "frame_index" in names else list(range(len(episode_col)))
    )
    task_col = table.column("task_index").to_pylist() if "task_index" in names else None

    def _build(
        ep: int,
        start: int,
        end: int,
        task_idx: int | None,
        ts_buf: list[float],
        fi_buf: list[int],
    ) -> EpisodeRecord | None:
        if only_set is not None and ep not in only_set:
            return None
        task = tasks.get(task_idx, "") if task_idx is not None else ""
        return EpisodeRecord(
            episode_index=ep,
            episode_task=task,
            frame_timestamps=tuple(ts_buf),
            frame_indices=tuple(fi_buf),
            data_path=path,
            row_offset=start,
            row_count=end - start,
        )

    cur_ep: int | None = None
    start_offset = 0
    ts_buf: list[float] = []
    fi_buf: list[int] = []
    cur_task_idx: int | None = None

    for i, ep in enumerate(episode_col):
        if cur_ep is None:
            cur_ep = ep
            start_offset = i
            ts_buf = [timestamp_col[i]]
            fi_buf = [frame_col[i]]
            cur_task_idx = task_col[i] if task_col is not None else None
            continue
        if ep != cur_ep:
            rec = _build(cur_ep, start_offset, i, cur_task_idx, ts_buf, fi_buf)
            if rec is not None:
                yield rec
            cur_ep = ep
            start_offset = i
            ts_buf = [timestamp_col[i]]
            fi_buf = [frame_col[i]]
            cur_task_idx = task_col[i] if task_col is not None else None
        else:
            ts_buf.append(timestamp_col[i])
            fi_buf.append(frame_col[i])

    if cur_ep is not None:
        rec = _build(cur_ep, start_offset, len(episode_col), cur_task_idx, ts_buf, fi_buf)
        if rec is not None:
            yield rec


def gather_data_paths(root: Path) -> list[Path]:
    """Return every ``data/chunk-*/file-*.parquet`` path under ``root``."""
    return sorted((root / "data").rglob("*.parquet"))


def episode_offsets_per_path(path: Path) -> dict[int, tuple[int, int]]:
    """Return ``{episode_index: (row_offset, row_count)}`` for one parquet."""
    table = pq.read_table(path, columns=["episode_index"])
    episode_col = table.column("episode_index").to_pylist()
    out: dict[int, tuple[int, int]] = {}
    cur_ep: int | None = None
    start = 0
    for i, ep in enumerate(episode_col):
        if cur_ep is None:
            cur_ep = ep
            start = i
            continue
        if ep != cur_ep:
            out[cur_ep] = (start, i - start)
            cur_ep = ep
            start = i
    if cur_ep is not None:
        out[cur_ep] = (start, len(episode_col) - start)
    return out


def keyframe_indices(record: EpisodeRecord, k: int) -> list[int]:
    """Return ``k`` evenly spaced row indices into the episode (relative)."""
    n = record.row_count
    if k <= 0 or n == 0:
        return []
    if k >= n:
        return list(range(n))
    step = (n - 1) / (k - 1) if k > 1 else 0.0
    return [int(round(i * step)) for i in range(k)] if k > 1 else [n // 2]


def lookup_data_path(root: Path, episode_index: int) -> tuple[Path, int, int] | None:
    """Find the parquet file containing ``episode_index`` and its slice bounds."""
    for path in gather_data_paths(root):
        offsets = episode_offsets_per_path(path)
        if episode_index in offsets:
            start, count = offsets[episode_index]
            return path, start, count
    return None


def episode_frame_timestamps(root: Path, episode_index: int) -> tuple[Any, list[float]]:
    """Return the parquet path and per-frame timestamps for ``episode_index``."""
    found = lookup_data_path(root, episode_index)
    if found is None:
        raise ValueError(f"Episode {episode_index} not found under {root}/data/")
    path, start, count = found
    table = pq.read_table(path, columns=["timestamp"])
    timestamps = table.column("timestamp").to_pylist()[start : start + count]
    return path, [float(t) for t in timestamps]
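A minimal usage sketch, assuming a local dataset laid out as the module docstring describes (the root path is illustrative):

from pathlib import Path

root = Path("/data/my_lerobot_dataset")  # hypothetical dataset root
for record in iter_episodes(root, only_episodes=(0, 1)):
    # Metadata is cheap; the per-episode frame table is only read on demand.
    print(record.episode_index, record.episode_task, record.row_count)
    frames = record.frames_df()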
@@ -0,0 +1,98 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-episode staging.

Each module writes its raw output as a JSONL file under
``<staging_dir>/episode_{ep:06d}/<module>.jsonl``. The writer reads back this
staging tree and partitions rows into the two language columns.

JSONL is preferred over parquet here because the staging artifact is meant to
be human-inspectable, easy to diff between prompt iterations, and trivially
appended to. The final dataset format is parquet; staging is just an
intermediate.
"""

from __future__ import annotations

import json
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Any

ModuleName = str

_MODULES: tuple[ModuleName, ...] = (
    "module_1",
    "module_2",
    "module_3",
)


@dataclass
class EpisodeStaging:
    """Filesystem layout for a single episode's staged module outputs."""

    root: Path
    episode_index: int

    @property
    def episode_dir(self) -> Path:
        return self.root / f"episode_{self.episode_index:06d}"

    def path_for(self, module: ModuleName) -> Path:
        if module not in _MODULES:
            raise ValueError(f"Unknown module {module!r}; expected one of {_MODULES}")
        return self.episode_dir / f"{module}.jsonl"

    def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
        path = self.path_for(module)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
                f.write("\n")
        return path

    def read(self, module: ModuleName) -> list[dict[str, Any]]:
        path = self.path_for(module)
        if not path.exists():
            return []
        out: list[dict[str, Any]] = []
        with path.open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    out.append(json.loads(line))
        return out

    def read_all(self) -> dict[ModuleName, list[dict[str, Any]]]:
        return {m: self.read(m) for m in _MODULES}

    def has(self, module: ModuleName) -> bool:
        return self.path_for(module).exists()


def iter_staged_episodes(root: Path) -> Iterator[int]:
    """Yield episode indices for which any staging artifact exists."""
    if not root.exists():
        return
    for child in sorted(root.iterdir()):
        if child.is_dir() and child.name.startswith("episode_"):
            try:
                yield int(child.name.removeprefix("episode_"))
            except ValueError:
                continue
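A round-trip sketch against a temporary staging root (names and values illustrative):

import tempfile
from pathlib import Path

staging_root = Path(tempfile.mkdtemp())
staging = EpisodeStaging(root=staging_root, episode_index=3)
staging.write(
    "module_1",
    [{"role": "assistant", "content": "1. open box", "style": "plan", "timestamp": 0.0, "tool_calls": None}],
)
assert staging.has("module_1")
assert staging.read("module_1")[0]["style"] == "plan"
assert list(iter_staged_episodes(staging_root)) == [3]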
@@ -0,0 +1,334 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pre-write validation against staged outputs.

Runs after Modules 1–3 have all written their per-episode artifacts but
*before* the writer rewrites parquet shards. The validator never touches
parquet; it only inspects the staging tree and the source frame timestamps
exposed by :class:`EpisodeRecord`.

Checks (per the plan's "Intermediate staging and validation" section):

- exact timestamp alignment against source frame timestamps
- no orphan speech / interjection pairs
- plan / memory emission consistency (events have a paired persistent row)
- VQA assistant ``content`` is valid JSON (one of bbox / keypoint / count /
  attribute / spatial)
- every row maps to its correct column under :func:`column_for_style`
"""

from __future__ import annotations

import json
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from lerobot.datasets.language import (
    LANGUAGE_EVENTS,
    LANGUAGE_PERSISTENT,
    column_for_style,
    is_view_dependent_style,
    validate_camera_field,
)

from .reader import EpisodeRecord
from .staging import EpisodeStaging

logger = logging.getLogger(__name__)


@dataclass
class ValidationReport:
    """Outcome of one validation pass across all episodes."""

    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    episodes_checked: int = 0

    @property
    def ok(self) -> bool:
        return not self.errors

    def add_error(self, message: str) -> None:
        self.errors.append(message)

    def add_warning(self, message: str) -> None:
        self.warnings.append(message)

    def summary(self) -> str:
        return f"checked={self.episodes_checked} errors={len(self.errors)} warnings={len(self.warnings)}"


VQA_ANSWER_SHAPES: dict[str, set[str]] = {
    "bbox": {"detections"},
    "keypoint": {"label", "point_format", "point"},
    "count": {"label", "count"},
    "attribute": {"label", "attribute", "value"},
    "spatial": {"subject", "relation", "object"},
}


def classify_vqa_answer(payload: Any) -> str | None:
    """Best-effort classification of a VQA answer payload to a question type."""
    if not isinstance(payload, dict):
        return None
    keys = set(payload.keys())
    for kind, required in VQA_ANSWER_SHAPES.items():
        if required.issubset(keys):
            return kind
    return None


@dataclass
class StagingValidator:
    """Walks the staging tree and produces a :class:`ValidationReport`."""

    timestamp_atol: float = 0.0  # exact-match by default
    dataset_camera_keys: tuple[str, ...] | None = None
    """Known ``observation.images.*`` keys on the dataset. When set, the
    validator additionally enforces that every view-dependent row's
    ``camera`` field references one of these keys. Pass ``None`` (default)
    to skip that cross-check (e.g. in unit tests with no real dataset)."""

    def validate(
        self,
        records: Sequence[EpisodeRecord],
        staging_dir: Path,
    ) -> ValidationReport:
        report = ValidationReport()
        for record in records:
            self._validate_episode(record, staging_dir, report)
            report.episodes_checked += 1
        return report

    def _validate_episode(
        self,
        record: EpisodeRecord,
        staging_dir: Path,
        report: ValidationReport,
    ) -> None:
        staging = EpisodeStaging(staging_dir, record.episode_index)
        staged = staging.read_all()
        all_rows: list[dict[str, Any]] = []
        for module_name, rows in staged.items():
            for row in rows:
                row = {**row, "_module": module_name}
                all_rows.append(row)

        frame_ts = set(record.frame_timestamps)

        events: list[dict[str, Any]] = []
        persistent: list[dict[str, Any]] = []
        for row in all_rows:
            self._check_column_routing(row, report, record.episode_index)
            self._check_camera_field(
                row, report, record.episode_index, self.dataset_camera_keys
            )
            if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
                persistent.append(row)
            else:
                events.append(row)

        for row in events:
            self._check_event_timestamp_alignment(row, frame_ts, report, record.episode_index)

        self._check_speech_interjection_pairs(events, report, record.episode_index)
        self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
        self._check_vqa_json(events, report, record.episode_index)
        self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)

    def _check_camera_field(
        self,
        row: dict[str, Any],
        report: ValidationReport,
        episode_index: int,
        dataset_camera_keys: Sequence[str] | None,
    ) -> None:
        """Enforce the camera invariant + that the key matches the dataset's cameras."""
        style = row.get("style")
        camera = row.get("camera")
        try:
            validate_camera_field(style, camera)
        except ValueError as exc:
            report.add_error(
                f"ep={episode_index} module={row.get('_module')}: {exc}"
            )
            return
        if (
            is_view_dependent_style(style)
            and dataset_camera_keys
            and camera not in dataset_camera_keys
        ):
            report.add_error(
                f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
                f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
            )

    def _check_vqa_uniqueness_per_frame_camera(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        """Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
        counts: dict[tuple[float, str, str], int] = {}
        for row in events:
            if row.get("style") != "vqa":
                continue
            ts = row.get("timestamp")
            camera = row.get("camera")
            role = row.get("role")
            if ts is None or camera is None or role is None:
                continue  # other validators flag these
            key = (float(ts), str(camera), str(role))
            counts[key] = counts.get(key, 0) + 1
        for (ts, camera, role), n in counts.items():
            if n > 1:
                report.add_error(
                    f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
                    f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
                )

    def _check_column_routing(
        self,
        row: dict[str, Any],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        style = row.get("style")
        module = row.get("_module")
        try:
            target_col = column_for_style(style)
        except ValueError:
            report.add_error(f"ep={episode_index} module={module}: unknown style {style!r}")
            return
        if module == "module_1" and target_col != LANGUAGE_PERSISTENT:
            report.add_error(
                f"ep={episode_index} module=module_1 emitted style {style!r} that routes to {target_col} (must be persistent)"
            )
        if module in {"module_2", "module_3"} and target_col != LANGUAGE_EVENTS:
            report.add_error(
                f"ep={episode_index} module={module} emitted style {style!r} that routes to {target_col} (must be events)"
            )

    def _check_event_timestamp_alignment(
        self,
        row: dict[str, Any],
        frame_ts: set[float],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        ts = row.get("timestamp")
        if ts is None:
            report.add_error(f"ep={episode_index}: event row missing timestamp: {row!r}")
            return
        if self.timestamp_atol == 0.0:
            if float(ts) not in frame_ts:
                report.add_error(
                    f"ep={episode_index}: event row timestamp {ts!r} does not match any source frame timestamp"
                )
        else:
            if not any(abs(float(ts) - f) <= self.timestamp_atol for f in frame_ts):
                report.add_error(
                    f"ep={episode_index}: event row timestamp {ts!r} not within {self.timestamp_atol}s of any frame"
                )

    def _check_speech_interjection_pairs(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        speech_ts: dict[float, int] = {}
        interjection_ts: dict[float, int] = {}
        for row in events:
            ts = row.get("timestamp")
            if ts is None:
                continue
            ts_f = float(ts)
            if row.get("style") is None and row.get("role") == "assistant":
                speech_ts[ts_f] = speech_ts.get(ts_f, 0) + 1
            if row.get("style") == "interjection":
                interjection_ts[ts_f] = interjection_ts.get(ts_f, 0) + 1

        for ts in interjection_ts:
            if ts not in speech_ts:
                report.add_error(f"ep={episode_index}: interjection at t={ts} has no paired speech atom")

    def _check_plan_memory_consistency(
        self,
        persistent: Sequence[dict[str, Any]],
        events: Sequence[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        plan_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "plan"})
        memory_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "memory"})
        subtask_ts = sorted({float(r["timestamp"]) for r in persistent if r.get("style") == "subtask"})
        interjection_ts = sorted(
            {
                float(r["timestamp"])
                for r in events
                if r.get("style") == "interjection" and r.get("timestamp") is not None
            }
        )

        if persistent and not plan_ts:
            report.add_warning(f"ep={episode_index}: persistent rows present but no plan emitted")
        # every interjection should have a same-timestamp plan refresh
        for ts in interjection_ts:
            if ts not in set(plan_ts):
                report.add_error(
                    f"ep={episode_index}: interjection at t={ts} has no co-timestamped plan update"
                )
        # memory should be emitted at subtask boundaries (subset relation)
        if memory_ts and subtask_ts:
            mem_set = set(memory_ts)
            sub_set = set(subtask_ts)
            stray = sorted(mem_set - sub_set)
            if stray:
                report.add_warning(f"ep={episode_index}: memory rows at {stray} not at any subtask boundary")

    def _check_vqa_json(
        self,
        events: Iterable[dict[str, Any]],
        report: ValidationReport,
        episode_index: int,
    ) -> None:
        for row in events:
            if row.get("style") != "vqa" or row.get("role") != "assistant":
                continue
            content = row.get("content")
            if content is None:
                report.add_error(
                    f"ep={episode_index}: VQA assistant row at t={row.get('timestamp')} has null content"
                )
                continue
            try:
                payload = json.loads(content)
            except (TypeError, ValueError) as exc:
                report.add_error(
                    f"ep={episode_index}: VQA assistant content not valid JSON at t={row.get('timestamp')}: {exc}"
                )
                continue
            shape = classify_vqa_answer(payload)
            if shape is None:
                report.add_error(
                    f"ep={episode_index}: VQA assistant payload at t={row.get('timestamp')} does not match any known shape: "
                    f"keys={list(payload) if isinstance(payload, dict) else type(payload).__name__}"
                )
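A sketch of wiring the validator between the modules and the writer, assuming ``root`` is the dataset root used with the reader above (paths illustrative); any error should abort before parquet is rewritten:

records = list(iter_episodes(root))  # from the reader module earlier in this diff
validator = StagingValidator(dataset_camera_keys=("observation.images.top",))
report = validator.validate(records, staging_dir=root / ".annotate_staging")
if not report.ok:
    for err in report.errors:
        logger.error(err)
    raise SystemExit(f"staging validation failed: {report.summary()}")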
@@ -0,0 +1,741 @@
#!/usr/bin/env python

# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared Qwen-VL client.

The pipeline uses a single shared VLM across modules. vLLM is preferred when
available (high throughput, JSON-guided decoding); transformers is the
fallback. A ``stub`` backend is used for unit tests so fixtures never call
into a real model.

The client speaks one method, :meth:`VlmClient.generate_json`, which:

- accepts a list of OpenAI/HF-style multimodal messages,
- requests JSON output (``json_mode=True`` enables guided decoding when the
  backend supports it),
- batches requests transparently,
- and reprompts once on a JSON parse failure with an inline correction
  message before giving up and returning ``None`` for that request.
"""

from __future__ import annotations

import json
import os
import threading
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any, Protocol

from .config import VlmConfig


class VlmClient(Protocol):
    """Protocol every backend must implement."""

    def generate_json(
        self,
        messages_batch: Sequence[Sequence[dict[str, Any]]],
        *,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[Any]:
        """Generate one JSON-decoded response per messages list."""


@dataclass
class StubVlmClient:
    """Deterministic stub used in unit tests.

    A test passes a callable that maps the *last user message text* (or, if
    that is empty, the full message list) to a JSON-serializable response.
    """

    responder: Callable[[Sequence[dict[str, Any]]], Any]

    def generate_json(
        self,
        messages_batch: Sequence[Sequence[dict[str, Any]]],
        *,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[Any]:
        return [self.responder(list(messages)) for messages in messages_batch]


def _strip_to_json(text: str) -> Any:
    text = text.strip()
    # Strip <think>...</think> blocks (Qwen3 Thinking style)
    while "<think>" in text and "</think>" in text:
        start = text.find("<think>")
        end = text.find("</think>", start) + len("</think>")
        text = (text[:start] + text[end:]).strip()
    # Strip ```json ... ``` fences from chat-tuned backbones
    if text.startswith("```"):
        first = text.find("\n")
        last = text.rfind("```")
        if first != -1 and last != -1 and last > first:
            text = text[first + 1 : last].strip()
    try:
        return json.loads(text)
    except (ValueError, json.JSONDecodeError):
        pass
    # Fall back to extracting the first balanced {...} block.
    obj_text = _extract_first_json_object(text)
    if obj_text is None:
        raise json.JSONDecodeError("No JSON object found", text, 0)
    return json.loads(obj_text)


def _extract_first_json_object(text: str) -> str | None:
    """Return the first balanced ``{...}`` substring, ignoring braces in
    string literals. Returns ``None`` if no balanced block is found."""
    start = text.find("{")
    if start < 0:
        return None
    depth = 0
    in_string = False
    escape = False
    for i in range(start, len(text)):
        ch = text[i]
        if escape:
            escape = False
            continue
        if ch == "\\":
            escape = True
            continue
        if ch == '"' and not escape:
            in_string = not in_string
            continue
        if in_string:
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start : i + 1]
    return None


@dataclass
class _GenericTextClient:
    """Wraps any text-generation callable in JSON-mode + one-retry semantics."""

    generate_text: Callable[[Sequence[Sequence[dict[str, Any]]], int, float], list[str]]
    config: VlmConfig

    def generate_json(
        self,
        messages_batch: Sequence[Sequence[dict[str, Any]]],
        *,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[Any]:
        max_tok = max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
        temp = temperature if temperature is not None else self.config.temperature
        raw = self.generate_text(messages_batch, max_tok, temp)
        out: list[Any] = []
        for messages, text in zip(messages_batch, raw, strict=True):
            try:
                out.append(_strip_to_json(text))
                continue
            except (ValueError, json.JSONDecodeError):
                pass
            retry = list(messages) + [
                {"role": "assistant", "content": text},
                {
                    "role": "user",
                    "content": (
                        "Your previous reply was not valid JSON. "
                        "Reply with strictly valid JSON, no prose, no fences."
                    ),
                },
            ]
            retry_text = self.generate_text([retry], max_tok, temp)[0]
            try:
                out.append(_strip_to_json(retry_text))
            except (ValueError, json.JSONDecodeError):
                # After retry: log preview and return None instead of crashing
                # the whole pipeline. Modules treat None as "skip".
                preview = retry_text.strip().replace("\n", " ")[:200]
                print(
                    f"[vlm] WARNING: failed to parse JSON after retry; preview: {preview!r}",
                    flush=True,
                )
                out.append(None)
        return out


def make_vlm_client(config: VlmConfig) -> VlmClient:
    """Build the shared VLM client per the configured backend.

    For ``stub``, callers should construct :class:`StubVlmClient` directly with
    a responder callable. ``stub`` here is rejected to make accidental misuse
    obvious.
    """
    if config.backend == "stub":
        raise ValueError(
            "Use StubVlmClient(...) directly for the stub backend; make_vlm_client builds real clients."
        )
    if config.backend == "vllm":
        return _make_vllm_client(config)
    if config.backend == "transformers":
        return _make_transformers_client(config)
    if config.backend == "openai":
        return _make_openai_client(config)
    raise ValueError(f"Unknown VLM backend: {config.backend!r}")


def _make_vllm_client(config: VlmConfig) -> VlmClient:
    try:
        from vllm import LLM, SamplingParams  # type: ignore[import-not-found]
    except ImportError as exc:
        raise ImportError(
            "vllm is required for backend='vllm'. Install with `pip install lerobot[annotations]`."
        ) from exc
    # Workaround for cuDNN 9.x + torch 2.8 conv3d regression that surfaces
    # as CUDNN_STATUS_NOT_INITIALIZED in Qwen-VL vision-tower patch
    # embedders. Setting LEROBOT_DISABLE_CUDNN=1 forces native PyTorch
    # convolution kernels — slower but functional.
    import os as _os  # noqa: PLC0415

    if _os.environ.get("LEROBOT_DISABLE_CUDNN", "").lower() in {"1", "true", "yes"}:
        import torch as _torch  # noqa: PLC0415

        _torch.backends.cudnn.enabled = False
    llm_kwargs: dict[str, Any] = {
        "model": config.model_id,
        "tensor_parallel_size": config.tensor_parallel_size,
        "gpu_memory_utilization": config.gpu_memory_utilization,
        "trust_remote_code": config.trust_remote_code,
    }
    if config.max_model_len is not None:
        llm_kwargs["max_model_len"] = config.max_model_len
    llm = LLM(**llm_kwargs)

    def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
        # ``guided_decoding`` would speed up parsing but its API differs across
        # vllm releases (dict vs GuidedDecodingParams). The _GenericTextClient
        # wrapper already has a one-retry JSON-recovery path, so we skip it.
        params = SamplingParams(max_tokens=max_tok, temperature=temp)
        # ``llm.chat`` handles chat-template application + multimodal input
        # extraction (image/video blocks) internally, which ``llm.generate``
        # does not.
        outputs = llm.chat([list(m) for m in batch], params)
        return [o.outputs[0].text for o in outputs]

    return _GenericTextClient(_gen, config)


def _make_transformers_client(config: VlmConfig) -> VlmClient:
    try:
        import torch  # type: ignore[import-not-found]
        import transformers  # type: ignore[import-not-found]
        from transformers import AutoProcessor  # type: ignore[import-not-found]
    except ImportError as exc:
        raise ImportError("transformers + torch are required for backend='transformers'.") from exc
    auto_cls = (
        getattr(transformers, "AutoModelForImageTextToText", None)
        or getattr(transformers, "AutoModelForVision2Seq", None)
    )
    if auto_cls is None:
        raise ImportError(
            "Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this "
            "transformers version. Install transformers>=4.45 (which has AutoModelForImageTextToText) "
            "for VL models."
        )
    processor = AutoProcessor.from_pretrained(
        config.model_id, trust_remote_code=config.trust_remote_code
    )
    import os as _os  # noqa: PLC0415

    use_accelerate = _os.environ.get("LEROBOT_TRANSFORMERS_DEVICE_MAP", "manual") != "manual"
    # ``device_map='auto'`` triggers a known std::bad_alloc on the Qwen3-VL
    # post-load dispatch path (the alloc fails in accelerate's hook setup
    # even with TBs of host RAM). Default to manual: load on CPU with
    # ``low_cpu_mem_usage=True``, then ``.to("cuda")``. Set
    # ``LEROBOT_TRANSFORMERS_DEVICE_MAP=auto`` to opt back into the old path.
    if use_accelerate:
        model = auto_cls.from_pretrained(
            config.model_id,
            torch_dtype="auto",
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=config.trust_remote_code,
        )
    else:
        import torch as _torch  # noqa: PLC0415

        model = auto_cls.from_pretrained(
            config.model_id,
            torch_dtype=_torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=config.trust_remote_code,
        )
        model = model.to("cuda")
    model.eval()

    def _gen(batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float) -> list[str]:
        outs: list[str] = []
        for messages in batch:
            text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            inputs = processor(text=[text], return_tensors="pt").to(model.device)
            with torch.no_grad():
                gen = model.generate(
                    **inputs,
                    max_new_tokens=max_tok,
                    temperature=temp,
                    do_sample=temp > 0.0,
                )
            decoded = processor.batch_decode(
                gen[:, inputs["input_ids"].shape[-1] :], skip_special_tokens=True
            )[0]
            outs.append(decoded)
        return outs

    return _GenericTextClient(_gen, config)


def _make_openai_client(config: VlmConfig) -> VlmClient:
    """Backend that talks to any OpenAI-compatible server.

    Compatible with ``vllm serve``, ``transformers serve``,
    ``ktransformers serve``, and hosted endpoints. By default the server
    is expected to be already running. Set ``auto_serve=True`` to have
    this client spawn one (default: ``transformers serve``), wait until
    it's ready, and tear it down on process exit.

    Image blocks ``{"type":"image", "image":<PIL.Image>}`` are
    auto-converted to ``image_url`` data-URLs. Video blocks
    ``{"type":"video", "video":[<PIL>...]}`` are forwarded as
    multi-frame ``video_url`` items where supported.
    """
    try:
        from openai import OpenAI  # type: ignore[import-not-found]
    except ImportError as exc:
        raise ImportError(
            "openai package is required for backend='openai'. "
            "Install with `pip install openai`."
        ) from exc

    api_base = config.api_base
    api_key = config.api_key
    auto_serve = config.auto_serve
    api_bases: list[str] = [api_base]

    print(
        f"[lerobot-annotate] backend=openai model={config.model_id} "
        f"api_base={api_base} auto_serve={auto_serve}",
        flush=True,
    )
    if auto_serve:
        if config.parallel_servers > 1:
            print(
                f"[lerobot-annotate] spawning {config.parallel_servers} parallel servers",
                flush=True,
            )
            api_bases = _spawn_parallel_inference_servers(config)
        elif _server_is_up(api_base):
            print(f"[lerobot-annotate] reusing server already up at {api_base}", flush=True)
        else:
            print("[lerobot-annotate] no server reachable; spawning one", flush=True)
            api_base = _spawn_inference_server(config)
            api_bases = [api_base]
        print(f"[lerobot-annotate] server ready at {api_base}", flush=True)

    clients = [OpenAI(base_url=base, api_key=api_key) for base in api_bases]
    client = clients[0]
    # round-robin counter for parallel mode
    rr_counter = {"i": 0}

    # ``mm_processor_kwargs`` is a vllm-specific extra; transformers serve
    # rejects it with HTTP 422. Send it only when explicitly opted in via
    # an env var (e.g. ``LEROBOT_OPENAI_SEND_MM_KWARGS=1`` for vllm).
    send_mm_kwargs = os.environ.get(
        "LEROBOT_OPENAI_SEND_MM_KWARGS", ""
    ).lower() in {"1", "true", "yes"}

    rr_lock = threading.Lock()

    def _one_call(
        messages: Sequence[dict[str, Any]], max_tok: int, temp: float
|
||||
) -> str:
|
||||
api_messages, mm_kwargs = _to_openai_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": config.model_id,
|
||||
"messages": api_messages,
|
||||
"max_tokens": max_tok,
|
||||
"temperature": temp,
|
||||
}
|
||||
extra_body: dict[str, Any] = {}
|
||||
if send_mm_kwargs and mm_kwargs:
|
||||
extra_body["mm_processor_kwargs"] = {**mm_kwargs, "do_sample_frames": True}
|
||||
if config.chat_template_kwargs:
|
||||
extra_body["chat_template_kwargs"] = config.chat_template_kwargs
|
||||
if extra_body:
|
||||
kwargs["extra_body"] = extra_body
|
||||
with rr_lock:
|
||||
chosen = clients[rr_counter["i"] % len(clients)]
|
||||
rr_counter["i"] += 1
|
||||
response = chosen.chat.completions.create(**kwargs)
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
def _gen(
|
||||
batch: Sequence[Sequence[dict[str, Any]]], max_tok: int, temp: float
|
||||
) -> list[str]:
|
||||
if len(batch) <= 1 or config.client_concurrency <= 1:
|
||||
return [_one_call(messages, max_tok, temp) for messages in batch]
|
||||
# Parallel fan-out — vllm batches these on the server side.
|
||||
from concurrent.futures import ThreadPoolExecutor # noqa: PLC0415
|
||||
|
||||
max_workers = min(config.client_concurrency, len(batch))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = [
|
||||
pool.submit(_one_call, messages, max_tok, temp) for messages in batch
|
||||
]
|
||||
return [f.result() for f in futures]
|
||||
|
||||
return _GenericTextClient(_gen, config)
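# --- Illustrative sketch (not part of the module) -------------------------
# Rough shape of the inputs this backend consumes, assuming a VlmConfig with
# the attributes referenced above (model_id, api_base, api_key, auto_serve,
# ...) and a vllm / transformers serve endpoint already listening. Model id,
# URL, and file name below are placeholders. Each message's ``content`` is a
# list of blocks; ``image`` blocks carry PIL images and are converted to
# ``image_url`` data-URLs by ``_to_openai_messages`` before the request goes out.
#
#   >>> from PIL import Image
#   >>> messages = [{
#   ...     "role": "user",
#   ...     "content": [
#   ...         {"type": "image", "image": Image.open("frame.png")},  # hypothetical file
#   ...         {"type": "text", "text": "Describe what the robot is doing."},
#   ...     ],
#   ... }]
#   # client = _make_openai_client(config)  # config: model_id="...", api_base="http://localhost:8000/v1"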
|
||||
|
||||
|
||||
def _spawn_parallel_inference_servers(config: VlmConfig) -> list[str]:
|
||||
"""Spawn ``config.parallel_servers`` independent vllm replicas.
|
||||
|
||||
Each replica:
|
||||
- is pinned to a single GPU via ``CUDA_VISIBLE_DEVICES``
|
||||
- listens on ``serve_port + i``
|
||||
- is shut down via the same atexit hook as the single-server path
|
||||
|
||||
Returns the list of ``api_base`` URLs the client should round-robin
|
||||
across.
|
||||
"""
|
||||
import atexit # noqa: PLC0415
|
||||
import os as _os # noqa: PLC0415
|
||||
import shlex # noqa: PLC0415
|
||||
import signal # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
import sys # noqa: PLC0415
|
||||
import threading # noqa: PLC0415
|
||||
import time # noqa: PLC0415
|
||||
|
||||
n = config.parallel_servers
|
||||
api_bases: list[str] = []
|
||||
procs: list[subprocess.Popen] = []
|
||||
ready_events: list[threading.Event] = []
|
||||
# Multiple readiness signals — uvicorn's own banner is suppressed at
|
||||
# ``--uvicorn-log-level warning``, so we also accept vllm's own
|
||||
# "Starting vLLM API server" line and the route-listing line. The
|
||||
# HTTP probe below is the ultimate fallback.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
# Single lock for all server-stream threads so multibyte chars from
|
||||
# different servers don't interleave and tear UTF-8 sequences.
|
||||
print_lock = threading.Lock()
|
||||
|
||||
base_cmd = config.serve_command or (
|
||||
f"vllm serve {shlex.quote(config.model_id)} "
|
||||
f"--tensor-parallel-size 1 "
|
||||
f"--max-model-len {config.max_model_len or 32768} "
|
||||
f"--uvicorn-log-level warning"
|
||||
)
|
||||
|
||||
num_gpus = config.num_gpus if config.num_gpus > 0 else n
|
||||
for i in range(n):
|
||||
port = config.serve_port + i
|
||||
gpu = i % num_gpus
|
||||
env = _os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
||||
cmd = base_cmd
|
||||
if "{port}" in cmd:
|
||||
cmd = cmd.replace("{port}", str(port))
|
||||
else:
|
||||
cmd = f"{cmd} --port {port}"
|
||||
api_base = f"http://localhost:{port}/v1"
|
||||
api_bases.append(api_base)
|
||||
print(f"[server-{i}] launching on GPU {gpu} port {port}: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
env=env,
|
||||
)
|
||||
procs.append(proc)
|
||||
ready = threading.Event()
|
||||
ready_events.append(ready)
|
||||
|
||||
def _stream(idx: int, p: subprocess.Popen, ev: threading.Event) -> None:
|
||||
# Read whole lines and emit each line atomically under the
|
||||
# shared print_lock so output from N servers stays readable.
|
||||
assert p.stdout is not None
|
||||
for line in iter(p.stdout.readline, ""):
|
||||
with print_lock:
|
||||
sys.stdout.write(f"[server-{idx}] {line}")
|
||||
if not line.endswith(("\n", "\r")):
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
if any(m in line for m in ready_markers):
|
||||
ev.set()
|
||||
|
||||
threading.Thread(target=_stream, args=(i, proc, ready), daemon=True).start()
|
||||
|
||||
def _probe(idx: int, base: str, ev: threading.Event, p: subprocess.Popen) -> None:
|
||||
while not ev.is_set() and p.poll() is None:
|
||||
if _server_is_up(base):
|
||||
print(f"[server-{idx}] ready (http probe)", flush=True)
|
||||
ev.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, args=(i, api_base, ready, proc), daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is None:
|
||||
print(f"[server-{i}] stopping pid={p.pid}", flush=True)
|
||||
p.send_signal(signal.SIGINT)
|
||||
for p in procs:
|
||||
try:
|
||||
p.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
p.kill()
|
||||
p.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while any(not ev.is_set() for ev in ready_events) and time.monotonic() < deadline:
|
||||
for i, p in enumerate(procs):
|
||||
if p.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server-{i}] inference server exited unexpectedly with rc={p.returncode}"
|
||||
)
|
||||
time.sleep(2)
|
||||
if any(not ev.is_set() for ev in ready_events):
|
||||
raise RuntimeError(
|
||||
f"[server] not all replicas became ready within {config.serve_ready_timeout_s}s"
|
||||
)
|
||||
print(f"[lerobot-annotate] all {n} servers ready: {api_bases}", flush=True)
|
||||
return api_bases
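# --- Illustrative note (not part of the module) ----------------------------
# A custom ``serve_command`` may embed a ``{port}`` placeholder, e.g.
# (placeholder model id):
#
#   serve_command: "vllm serve my-org/my-vlm --max-model-len 16384 --port {port}"
#
# Each replica substitutes its own port (otherwise ``--port <p>`` is appended),
# is pinned to GPU ``i % num_gpus`` via CUDA_VISIBLE_DEVICES, and ends up
# serving ``http://localhost:<serve_port + i>/v1``.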
|
||||
|
||||
|
||||
def _server_is_up(api_base: str) -> bool:
|
||||
"""Return True if ``api_base/models`` answers 200 within 2 seconds."""
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=2) as resp:
|
||||
return resp.status == 200
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
def _spawn_inference_server(config: VlmConfig) -> str:
|
||||
"""Spawn ``transformers serve`` (or ``serve_command``), wait until it
|
||||
accepts ``/v1/models``, and register a shutdown hook.
|
||||
|
||||
Streams the server's stdout/stderr to the parent terminal in
|
||||
real-time on a background thread so users can see model-load
|
||||
progress and errors as they happen.
|
||||
|
||||
Returns the full ``api_base`` URL the OpenAI client should use.
|
||||
"""
|
||||
import atexit # noqa: PLC0415
|
||||
import shlex # noqa: PLC0415
|
||||
import signal # noqa: PLC0415
|
||||
import subprocess # noqa: PLC0415
|
||||
import sys # noqa: PLC0415
|
||||
import threading # noqa: PLC0415
|
||||
import time # noqa: PLC0415
|
||||
import urllib.request # noqa: PLC0415
|
||||
|
||||
cmd = config.serve_command
|
||||
if not cmd:
|
||||
cmd = (
|
||||
f"transformers serve {shlex.quote(config.model_id)} "
|
||||
f"--port {config.serve_port} --continuous-batching"
|
||||
)
|
||||
api_base = f"http://localhost:{config.serve_port}/v1"
|
||||
print(f"[server] launching: {cmd}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
shlex.split(cmd),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Watch the server output for the uvicorn readiness banner. This is
|
||||
# more reliable than polling /v1/models because transformers serve
|
||||
# rescans its cache on every model-list request, which can exceed
|
||||
# the urllib timeout and trigger an infinite probe loop.
|
||||
ready_event = threading.Event()
|
||||
# See _spawn_parallel_inference_servers for why we accept these.
|
||||
ready_markers = (
|
||||
"Uvicorn running",
|
||||
"Application startup complete",
|
||||
"Starting vLLM API server",
|
||||
"Available routes are",
|
||||
)
|
||||
|
||||
def _probe() -> None:
|
||||
while not ready_event.is_set() and proc.poll() is None:
|
||||
if _server_is_up(api_base):
|
||||
print("[server] ready (http probe)", flush=True)
|
||||
ready_event.set()
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
threading.Thread(target=_probe, daemon=True).start()
|
||||
|
||||
def _stream_output() -> None:
|
||||
# Read raw chunks instead of iterating lines so tqdm progress
|
||||
# bars (which overwrite using \r) flush in real time.
|
||||
assert proc.stdout is not None
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
while True:
|
||||
ch = proc.stdout.read(1)
|
||||
if ch == "":
|
||||
# process exited; flush any tail
|
||||
if buf:
|
||||
sys.stdout.write(buf)
|
||||
sys.stdout.flush()
|
||||
return
|
||||
if not prefix_started:
|
||||
sys.stdout.write("[server] ")
|
||||
prefix_started = True
|
||||
sys.stdout.write(ch)
|
||||
sys.stdout.flush()
|
||||
buf += ch
|
||||
if ch in ("\n", "\r"):
|
||||
if any(marker in buf for marker in ready_markers):
|
||||
ready_event.set()
|
||||
buf = ""
|
||||
prefix_started = False
|
||||
|
||||
threading.Thread(target=_stream_output, daemon=True).start()
|
||||
|
||||
def _shutdown() -> None:
|
||||
if proc.poll() is None:
|
||||
print(f"[server] stopping pid={proc.pid}", flush=True)
|
||||
proc.send_signal(signal.SIGINT)
|
||||
try:
|
||||
proc.wait(timeout=15)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
atexit.register(_shutdown)
|
||||
|
||||
deadline = time.monotonic() + config.serve_ready_timeout_s
|
||||
while time.monotonic() < deadline:
|
||||
if proc.poll() is not None:
|
||||
raise RuntimeError(
|
||||
f"[server] inference server exited unexpectedly with rc={proc.returncode}. "
|
||||
f"See [server] log lines above for the cause."
|
||||
)
|
||||
if ready_event.wait(timeout=2):
|
||||
return api_base
|
||||
proc.terminate()
|
||||
raise RuntimeError(
|
||||
f"[server] did not become ready within {config.serve_ready_timeout_s}s"
|
||||
)
|
||||
|
||||
|
||||
def _to_openai_messages(
|
||||
messages: Sequence[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, Any]]:
|
||||
"""Convert internal messages to OpenAI chat format.
|
||||
|
||||
Returns ``(api_messages, mm_kwargs)``. Multimodal-processor kwargs
|
||||
(``fps`` from ``video_url`` blocks) are extracted out so the caller
|
||||
can pass them via ``extra_body.mm_processor_kwargs`` rather than
|
||||
inside the content blocks (which transformers serve rejects).
|
||||
|
||||
File-URL video blocks are inlined as base64 data URLs.
|
||||
"""
|
||||
out_messages: list[dict[str, Any]] = []
|
||||
mm_kwargs: dict[str, Any] = {}
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
out_messages.append({"role": message["role"], "content": content})
|
||||
continue
|
||||
out_blocks: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
if block_type == "text":
|
||||
out_blocks.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(block["image"])}}
|
||||
)
|
||||
elif block_type == "video":
|
||||
frames = block.get("video", [])
|
||||
for img in frames:
|
||||
out_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": _pil_to_data_url(img)}}
|
||||
)
|
||||
elif block_type == "video_url":
|
||||
video_url = dict(block["video_url"])
|
||||
url = video_url.get("url", "")
|
||||
if url.startswith("file://"):
|
||||
video_url["url"] = _file_to_data_url(url[len("file://") :])
|
||||
out_blocks.append({"type": "video_url", "video_url": video_url})
|
||||
fps = block.get("fps")
|
||||
if fps is not None:
|
||||
mm_kwargs["fps"] = fps
|
||||
else:
|
||||
out_blocks.append(block)
|
||||
out_messages.append({"role": message["role"], "content": out_blocks})
|
||||
return out_messages, mm_kwargs
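# --- Illustrative sketch (not part of the module) --------------------------
# Example of the conversion above, assuming PIL is importable. A ``video_url``
# block carrying ``"fps"`` would additionally surface ``{"fps": ...}`` in the
# returned ``mm_kwargs`` (and a ``file://`` URL would be inlined as base64).
#
#   >>> from PIL import Image
#   >>> msgs = [{"role": "user", "content": [
#   ...     {"type": "text", "text": "What is in view?"},
#   ...     {"type": "image", "image": Image.new("RGB", (8, 8))},
#   ... ]}]
#   >>> api_msgs, mm_kwargs = _to_openai_messages(msgs)
#   >>> [b["type"] for b in api_msgs[0]["content"]]
#   ['text', 'image_url']
#   >>> mm_kwargs
#   {}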
|
||||
|
||||
|
||||
def _file_to_data_url(path: str) -> str:
|
||||
"""Read a local video file and return a base64 ``data:video/mp4`` URL."""
|
||||
import base64 # noqa: PLC0415
|
||||
|
||||
with open(path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:video/mp4;base64,{b64}"
|
||||
|
||||
|
||||
def _pil_to_data_url(image: Any) -> str:
|
||||
"""Encode a PIL.Image as a base64 data URL."""
|
||||
import base64 # noqa: PLC0415
|
||||
import io # noqa: PLC0415
|
||||
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def _messages_to_prompt(messages: Sequence[dict[str, Any]]) -> Any:
|
||||
"""Pass-through hook used by the vllm backend.
|
||||
|
||||
vllm exposes its own multimodal entry points that vary by version; for the
|
||||
base flow we simply forward the raw message list and let the caller's
|
||||
custom backend handle templating. Real deployments override this.
|
||||
"""
|
||||
return list(messages)
|
||||
@@ -0,0 +1,341 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Final parquet rewrite.
|
||||
|
||||
For every episode the writer:
|
||||
|
||||
1. reads the staged module outputs,
|
||||
2. partitions them into a persistent slice (PERSISTENT_STYLES) and an event
|
||||
slice (EVENT_ONLY_STYLES + style=None tool-call atoms),
|
||||
3. sorts each slice deterministically,
|
||||
4. broadcasts the persistent slice across every frame in the episode,
|
||||
5. for each frame, materializes the sublist of event rows whose timestamp
|
||||
exactly equals that frame's timestamp,
|
||||
6. drops the legacy ``subtask_index`` column,
|
||||
7. writes the parquet shard back in place.
|
||||
|
||||
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
|
||||
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
|
||||
struct (PR 1) for every speech atom. The tool *schema* (the description
|
||||
of the ``say`` function and its parameters) is a fixed code constant —
|
||||
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
|
||||
it directly rather than reading a redundant per-row column.
|
||||
|
||||
Invariants enforced here (and re-checked by the validator):
|
||||
|
||||
- per-episode persistent slice is byte-identical across every frame;
|
||||
- ``language_events`` rows on a frame all have ``timestamp == frame_ts``
|
||||
(timestamps come straight from the source parquet — never recomputed);
|
||||
- every row passes ``column_for_style(style)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
from .staging import EpisodeStaging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tool schema constants moved to lerobot.datasets.language in PR 1 — single
|
||||
# source of truth. Re-exported here so existing imports
|
||||
# (``from lerobot.annotations.steerable_pipeline.writer import SAY_TOOL_SCHEMA``)
|
||||
# keep working.
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS, SAY_TOOL_SCHEMA # noqa: F401, E402
|
||||
|
||||
|
||||
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
return (float(row["timestamp"]), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the persistent column's struct shape."""
|
||||
style = row.get("style")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
raise ValueError(
|
||||
f"persistent slice contains row with non-persistent style {style!r}; "
|
||||
"row would be misrouted under column_for_style()"
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"timestamp": float(row["timestamp"]),
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Coerce a staged row into the event column's struct shape (no timestamp)."""
|
||||
style = row.get("style")
|
||||
if style is not None and style not in EVENT_ONLY_STYLES:
|
||||
raise ValueError(
|
||||
f"event slice contains row with style {style!r}; expected None or one of {EVENT_ONLY_STYLES}"
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
|
||||
def _normalize_tool_calls(value: Any) -> list[Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"tool_calls must be a list or None, got {type(value).__name__}")
|
||||
return list(value)
|
||||
|
||||
|
||||
def _validate_atom_invariants(row: dict[str, Any]) -> None:
|
||||
"""At-least-one of content/tool_calls; style=None implies tool_calls."""
|
||||
has_content = row.get("content") is not None
|
||||
has_tools = row.get("tool_calls") is not None
|
||||
if not (has_content or has_tools):
|
||||
raise ValueError(f"row has neither content nor tool_calls: {row!r}")
|
||||
if row.get("style") is None and not has_tools:
|
||||
raise ValueError(f"style=None requires tool_calls: {row!r}")
|
||||
|
||||
|
||||
def _validate_speech_atom(row: dict[str, Any]) -> None:
|
||||
"""Speech atoms: role=assistant, style=None, content=None, say tool call."""
|
||||
if row.get("style") is not None:
|
||||
return # not a speech atom
|
||||
if row.get("role") != "assistant":
|
||||
raise ValueError(f"speech atom must have role=assistant: {row!r}")
|
||||
if row.get("content") is not None:
|
||||
raise ValueError(f"speech atom must have content=null: {row!r}")
|
||||
tool_calls = row.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
raise ValueError(f"speech atom must have non-empty tool_calls list: {row!r}")
|
||||
first = tool_calls[0]
|
||||
if not isinstance(first, dict):
|
||||
raise ValueError(f"speech atom tool_calls[0] must be a dict: {row!r}")
|
||||
if first.get("type") != "function":
|
||||
raise ValueError(f"speech atom tool_calls[0].type must be 'function': {row!r}")
|
||||
fn = first.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
raise ValueError(f"speech atom tool_calls[0].function.name must be 'say': {row!r}")
|
||||
args = fn.get("arguments") or {}
|
||||
if not isinstance(args, dict) or "text" not in args or not isinstance(args["text"], str):
|
||||
raise ValueError(f"speech atom must carry 'text' string in arguments: {row!r}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageColumnsWriter:
|
||||
"""Rewrite ``data/chunk-*/file-*.parquet`` with the two language columns."""
|
||||
|
||||
drop_existing_subtask_index: bool = True
|
||||
|
||||
def write_all(
|
||||
self,
|
||||
records: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> list[Path]:
|
||||
episodes_by_path: dict[Path, list[EpisodeRecord]] = defaultdict(list)
|
||||
for record in records:
|
||||
episodes_by_path[record.data_path].append(record)
|
||||
|
||||
written: list[Path] = []
|
||||
for path, eps in episodes_by_path.items():
|
||||
self._rewrite_one(path, eps, staging_dir, root)
|
||||
written.append(path)
|
||||
return written
|
||||
|
||||
def _rewrite_one(
|
||||
self,
|
||||
path: Path,
|
||||
episodes: Sequence[EpisodeRecord],
|
||||
staging_dir: Path,
|
||||
root: Path,
|
||||
) -> None:
|
||||
table = pq.read_table(path)
|
||||
n_rows = table.num_rows
|
||||
|
||||
# Ensure we cover every episode in the file. Episodes that don't have
|
||||
# staging artifacts are passed through with empty annotation lists —
|
||||
# this keeps the writer idempotent and safe for partial reruns.
|
||||
staged_per_ep: dict[int, dict[str, list[dict[str, Any]]]] = {}
|
||||
for record in episodes:
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
staged_per_ep[record.episode_index] = staging.read_all()
|
||||
|
||||
persistent_by_ep: dict[int, list[dict[str, Any]]] = {}
|
||||
events_by_ep_ts: dict[int, dict[float, list[dict[str, Any]]]] = {}
|
||||
|
||||
for ep_index, ep_staged in staged_per_ep.items():
|
||||
persistent_rows: list[dict[str, Any]] = []
|
||||
event_rows: list[dict[str, Any]] = [] # carry timestamp until bucketed
|
||||
for _module_name, rows in ep_staged.items():
|
||||
for row in rows:
|
||||
style = row.get("style")
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
persistent_rows.append(row)
|
||||
else:
|
||||
event_rows.append(row)
|
||||
|
||||
persistent_rows.sort(key=_row_persistent_sort_key)
|
||||
normalized_persistent = []
|
||||
for r in persistent_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
normalized_persistent.append(_normalize_persistent_row(r))
|
||||
persistent_by_ep[ep_index] = normalized_persistent
|
||||
|
||||
buckets: dict[float, list[dict[str, Any]]] = defaultdict(list)
|
||||
for r in event_rows:
|
||||
_validate_atom_invariants(r)
|
||||
_validate_speech_atom(r)
|
||||
ts = float(r["timestamp"])
|
||||
buckets[ts].append(_normalize_event_row(r))
|
||||
for ts in list(buckets.keys()):
|
||||
buckets[ts].sort(key=_row_event_sort_key)
|
||||
events_by_ep_ts[ep_index] = buckets
|
||||
|
||||
episode_col = (
|
||||
table.column("episode_index").to_pylist() if "episode_index" in table.column_names else None
|
||||
)
|
||||
ts_col = table.column("timestamp").to_pylist() if "timestamp" in table.column_names else None
|
||||
if episode_col is None or ts_col is None:
|
||||
raise ValueError(f"{path} is missing 'episode_index' or 'timestamp' — required by the writer.")
|
||||
|
||||
per_row_persistent: list[list[dict[str, Any]]] = []
|
||||
per_row_events: list[list[dict[str, Any]]] = []
|
||||
for i in range(n_rows):
|
||||
ep = episode_col[i]
|
||||
ts = float(ts_col[i])
|
||||
per_row_persistent.append(persistent_by_ep.get(ep, []))
|
||||
buckets = events_by_ep_ts.get(ep, {})
|
||||
per_row_events.append(buckets.get(ts, []))
|
||||
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
pq.write_table(new_table, path)
|
||||
|
||||
def _materialize_table(
|
||||
self,
|
||||
table: pa.Table,
|
||||
persistent: list[list[dict[str, Any]]],
|
||||
events: list[list[dict[str, Any]]],
|
||||
*,
|
||||
drop_old: bool,
|
||||
) -> pa.Table:
|
||||
cols = []
|
||||
names = []
|
||||
for name in table.column_names:
|
||||
if drop_old and name == "subtask_index":
|
||||
continue
|
||||
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
|
||||
continue # we'll re-add canonical versions
|
||||
# Strip any legacy ``tools`` column previously emitted by older
|
||||
# writers — the schema no longer uses it (constant lives in
|
||||
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
|
||||
if name == "tools":
|
||||
continue
|
||||
cols.append(table.column(name))
|
||||
names.append(name)
|
||||
|
||||
# We let pyarrow infer struct/list schema rather than passing the
# canonical type from `lerobot.datasets.language` directly: that type
# uses `pa.json_()` for the `tool_calls` element type, which
# `pa.array(..., type=...)` cannot materialize from Python lists on
# current pyarrow versions. The inferred schema round-trips through
# parquet and `LeRobotDataset` correctly — see PR 1's
# `tests/datasets/test_language.py` which exercises the same flow.
persistent_arr = pa.array(persistent)
|
||||
events_arr = pa.array(events)
|
||||
|
||||
cols.extend([persistent_arr, events_arr])
|
||||
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
|
||||
|
||||
return pa.Table.from_arrays(cols, names=names)
|
||||
|
||||
|
||||
def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"""Build a canonical speech tool-call atom for the events column."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"arguments": {"text": text},
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_rows_for_writer(
|
||||
rows: Iterable[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Helper used by tests/validators to partition a flat row list into
|
||||
(persistent_rows, event_rows) using ``column_for_style``.
|
||||
"""
|
||||
persistent: list[dict[str, Any]] = []
|
||||
events: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
events.append(row)
|
||||
return persistent, events
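# --- Illustrative sketch (not part of the module) --------------------------
# How a speech tool-call atom flows through the helpers above. ``speech_atom``
# builds the canonical events-column row, and ``normalize_rows_for_writer``
# routes it to the event slice, since (per the module docstring) style=None
# tool-call atoms resolve to ``language_events`` under ``column_for_style``:
#
#   >>> atom = speech_atom(3.2, "Handing you the cup now.")
#   >>> persistent, events = normalize_rows_for_writer([atom])
#   >>> len(persistent), len(events)
#   (0, 1)
#   >>> events[0]["tool_calls"][0]["function"]["name"]
#   'say'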
|
||||
@@ -23,6 +23,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -41,7 +42,10 @@ __all__ = [
|
||||
# Config classes
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate role, stream, and content after dataclass construction."""
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
if self.stream is not None and self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
"""Construct a :class:`MessageTurn` from a plain dictionary."""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that exactly one of ``messages`` or ``blend`` is set."""
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
"""Construct a :class:`TrainingRecipe` from a nested dictionary."""
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
"""Return the binding names that ``turn`` references via placeholders or attributes."""
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
"""Return the set of ``${name}`` placeholders found anywhere in ``content``."""
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(_PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(_PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
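# --- Illustrative sketch (not part of the module) --------------------------
# A minimal message recipe built in code rather than YAML. ``${task}`` is an
# implicit binding and ``${subtask}`` comes from DEFAULT_BINDINGS; at least
# one turn must set ``target: True`` or validation raises:
#
#   >>> recipe = TrainingRecipe.from_dict({
#   ...     "messages": [
#   ...         {"role": "user", "content": "${task}", "stream": "high_level"},
#   ...         {"role": "assistant", "content": "${subtask}", "stream": "low_level", "target": True},
#   ...     ]
#   ... })
#   >>> len(recipe.messages)
#   2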
|
||||
@@ -0,0 +1,74 @@
|
||||
blend:
|
||||
|
||||
memory_update:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "emitted_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.16
|
||||
bindings:
|
||||
prior_plan: "nth_prev(style=plan, offset=1)"
|
||||
current_plan: "emitted_at(t, style=plan)"
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.15
|
||||
bindings:
|
||||
next_subtask: "nth_next(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
|
||||
- {role: assistant, content: "${next_subtask}", stream: high_level, target: true}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.35
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
|
||||
|
||||
# VQA is view-dependent: bbox / keypoint / count answers only make sense for
|
||||
# the camera they were grounded against. Each camera gets its own sub-recipe
|
||||
# so the resolver can disambiguate via `camera=...` and the user-turn carries
|
||||
# the matching image block. Adjust the camera keys (and add more sub-recipes)
|
||||
# to match the cameras present on your dataset.
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.top}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -0,0 +1,107 @@
|
||||
# SmolVLA2 canonical training recipe — Hi Robot / MEM / ECoT blend.
#
# Same blend shape as pi05_hirobot.yaml. SmolVLA2 differs from Pi0.5 in
# how the renderer's output is consumed:
#
# - SmolVLA2 calls SmolVLM's tokenizer.apply_chat_template(messages,
#   tools=DEFAULT_TOOLS) on the rendered messages, since SmolVLM is a
#   chat-pretrained backbone.
# - The processor builds a `text_labels` tensor that masks every token
#   except those belonging to messages whose index is in
#   `target_message_indices`. Cross-entropy on those positions trains
#   the LM head.
# - `predict_actions = bool(targets_by_stream.get("low_level"))` —
#   same convention as Pi0.5. ``low_level_execution`` is the only
#   branch that runs the action expert / flow head.

blend:
|
||||
|
||||
memory_update:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "emitted_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.16
|
||||
bindings:
|
||||
prior_plan: "nth_prev(style=plan, offset=1)"
|
||||
current_plan: "emitted_at(t, style=plan)"
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
|
||||
|
||||
  # PR3 Hi-Robot v2: supervise the high-level head with the *current*
  # active subtask, not the *next*. Pi 0.5 / Pi 0.7 both do this: at every
  # frame the assistant target is "what is the robot doing right now"
  # grounded in the current image + state + context, so the supervision
  # target is always a non-empty span string.
  #
  # The original target was ``nth_next(style=subtask, offset=1)`` — at
  # most frames within a single span this resolves to the next-span
  # string (fine), but on the LAST span of an episode it resolves to
  # empty/None. The recipe had no ``if_present`` guard on the target,
  # so the renderer emitted an empty assistant turn and cross-entropy
  # ended up supervising the chat-template's structural newlines.
  # Across a dataset annotated this way, the LM head's argmax at
  # position 0 collapses to ``\n`` whenever no transition is happening
  # (which is most of the time). At inference: the head silently emits
  # newlines at every chunk boundary while the action expert keeps working.
  #
  # With ``${subtask}`` (binds to ``active_at(t, style=subtask)``) the
  # target is the current span's text — always non-empty, scene-
  # grounded. The runtime detects subtask transitions by comparing the
  # predicted subtask string to the last known one, the same way Pi 0.5
  # does. No information loss.
high_level_subtask:
|
||||
weight: 0.15
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.35
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
|
||||
|
||||
# Per-camera VQA sub-recipes (PR 1's view-dependent style routing).
|
||||
# Adjust the camera keys (and add more sub-recipes) to match the
|
||||
# cameras present on your dataset.
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -34,9 +34,16 @@ from .dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
@@ -45,6 +52,19 @@ from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
|
||||
def make_dataset(*args, **kwargs):
|
||||
from .factory import make_dataset as _make_dataset
|
||||
|
||||
return _make_dataset(*args, **kwargs)
|
||||
|
||||
|
||||
def resolve_delta_timestamps(*args, **kwargs):
|
||||
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
|
||||
|
||||
return _resolve_delta_timestamps(*args, **kwargs)
|
||||
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
# Import directly: ``from lerobot.datasets.io_utils import ...``
|
||||
@@ -53,10 +73,15 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"add_features",
|
||||
@@ -66,6 +91,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -34,7 +34,6 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
@@ -177,7 +176,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -320,6 +318,28 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
|
||||
def tools(self) -> list[dict]:
|
||||
"""OpenAI-style tool schemas declared by this dataset.
|
||||
|
||||
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||
can mutate the result safely. Falls back to
|
||||
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||
``say`` schema) when the dataset doesn't declare any — that way
|
||||
unannotated datasets and chat-template consumers
|
||||
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||
working out of the box.
|
||||
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
|
||||
|
||||
declared = self.info.get("tools")
|
||||
if isinstance(declared, list) and declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
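# --- Illustrative sketch (not part of the class) ----------------------------
# Typical consumer flow, assuming a chat tokenizer whose ``apply_chat_template``
# accepts ``tools=`` (repo id below is a placeholder):
#
#   >>> meta = LeRobotDatasetMetadata("my-org/annotated-dataset")
#   >>> prompt = tokenizer.apply_chat_template(
#   ...     messages, tools=meta.tools, add_generation_prompt=True, tokenize=False
#   ... )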
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -635,7 +655,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -126,10 +126,53 @@ class DatasetReader:
|
||||
def _load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self._meta.features)
|
||||
# Datasets annotated with the PR1 language columns may have been
|
||||
# written without registering those columns in ``meta/info.json``
|
||||
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
|
||||
# back-filled by ``lerobot-annotate``). Probe a single parquet
|
||||
# shard and graft the column features on so the strict
|
||||
# ``Dataset.from_parquet`` cast doesn't fail with
|
||||
# ``column names don't match``.
|
||||
features = self._extend_features_with_language_columns(features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def _extend_features_with_language_columns(
|
||||
self, features: datasets.Features
|
||||
) -> datasets.Features:
|
||||
"""Add ``language_persistent`` / ``language_events`` to ``features``
|
||||
when the underlying parquet shards declare them but the metadata
|
||||
doesn't. No-op when neither column is present or both are
|
||||
already registered.
|
||||
"""
|
||||
# Find any one parquet to peek at; bail if there are none yet
|
||||
# (the dataset will fail later for an unrelated reason and we
|
||||
# want that error to surface as-is).
|
||||
try:
|
||||
sample = next((self.root / "data").glob("*/*.parquet"))
|
||||
except StopIteration:
|
||||
return features
|
||||
|
||||
from pyarrow import parquet as _pq # noqa: PLC0415
|
||||
|
||||
schema_names = set(_pq.read_schema(sample).names)
|
||||
from .language import ( # noqa: PLC0415
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
|
||||
extra: dict[str, object] = {}
|
||||
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
|
||||
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
|
||||
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
|
||||
extra[LANGUAGE_EVENTS] = language_events_column_feature()
|
||||
if not extra:
|
||||
return features
|
||||
return datasets.Features({**features, **extra})
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
@@ -295,9 +338,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -22,6 +22,12 @@ from PIL import Image as PILImage
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -45,7 +51,13 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -242,6 +254,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return ""
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -189,14 +188,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -268,11 +259,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in {"language_persistent", "language_events"}:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -308,7 +301,11 @@ def item_to_torch(item: dict) -> dict:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in [
|
||||
"task",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||
|
||||
CORE_STYLES = {
|
||||
"subtask",
|
||||
"plan",
|
||||
"memory",
|
||||
"motion",
|
||||
"interjection",
|
||||
"vqa",
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
EXTENDED_STYLES = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||
|
||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||
# styles MUST carry a non-null ``camera`` referencing an ``observation.images.*``
|
||||
# feature key. Rows of every other style MUST have ``camera=None``. ``motion``
|
||||
# is intentionally NOT in this set: motion primitives are described in
|
||||
# robot-frame (joint / Cartesian) terms, not pixel space, so they are
|
||||
# camera-agnostic. ``trace`` is the pixel-trajectory event style and IS
|
||||
# view-dependent. The ``camera`` field nevertheless lives on
|
||||
# ``PERSISTENT_ROW_FIELDS`` too so the schema, validator, and resolver
|
||||
# behave symmetrically across the two columns; persistent rows simply
|
||||
# always have ``camera=None`` in practice today.
|
||||
VIEW_DEPENDENT_STYLES = {"vqa", "trace"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
"""Return the Arrow JSON type, falling back to ``string`` on older pyarrow."""
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF feature used for tool-call payloads.
|
||||
|
||||
Older ``datasets`` versions do not expose ``datasets.Json``. The
|
||||
annotation pipeline currently emits the canonical ``say`` tool call
|
||||
shape, so use that explicit struct instead of falling back to a string
|
||||
that cannot cast structured parquet values.
|
||||
"""
|
||||
if hasattr(datasets, "Json"):
|
||||
return datasets.Json()
|
||||
return {
|
||||
"type": datasets.Value("string"),
|
||||
"function": {
|
||||
"name": datasets.Value("string"),
|
||||
"arguments": {"text": datasets.Value("string")},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float64(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float64"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def is_view_dependent_style(style: str | None) -> bool:
|
||||
"""Return ``True`` if rows of ``style`` must be tagged with a ``camera`` key."""
|
||||
return style in VIEW_DEPENDENT_STYLES
|
||||
|
||||
|
||||
def validate_camera_field(style: str | None, camera: str | None) -> None:
|
||||
"""Enforce the ``camera`` invariant: required iff ``style`` is view-dependent.
|
||||
|
||||
Raises ``ValueError`` if a view-dependent style is missing ``camera`` or if
|
||||
a non-view-dependent style carries one. Pipeline writers and the validator
|
||||
should call this on every emitted row.
|
||||
"""
|
||||
if is_view_dependent_style(style):
|
||||
if not camera:
|
||||
raise ValueError(
|
||||
f"Rows of view-dependent style {style!r} require a non-empty 'camera' "
|
||||
f"field referencing an 'observation.images.*' feature key."
|
||||
)
|
||||
elif camera is not None:
|
||||
raise ValueError(
|
||||
f"Rows of style {style!r} must have camera=None; got camera={camera!r}."
|
||||
)
|
||||
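# Hedged usage sketch (not part of the original diff): the camera invariant in
# action. "observation.images.top" is an assumed camera key used only for
# illustration.
validate_camera_field("vqa", "observation.images.top")  # ok: view-dependent style with a camera
validate_camera_field("subtask", None)  # ok: persistent style, camera-agnostic
# validate_camera_field("trace", None) would raise: view-dependent styles need a camera
# validate_camera_field("subtask", "observation.images.top") would raise: camera must be None here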
|
||||
|
||||
# --- Tool registry --------------------------------------------------------
|
||||
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||
# fallback when the dataset doesn't declare any). Implementations live
|
||||
# under :mod:`lerobot.tools` (one file per tool); see
|
||||
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||
|
||||
SAY_TOOL_SCHEMA: dict = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "say",
|
||||
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The verbatim text to speak.",
|
||||
}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||
import this constant rather than duplicating the dict."""
|
||||
|
||||
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||
keep working out of the box."""
|
||||
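# Hedged sketch (not part of the original diff): how a chat-template consumer
# could thread the catalog through. ``meta`` stands for a LeRobotDatasetMetadata
# instance and ``tokenizer`` for an HF tokenizer whose chat template accepts
# OpenAI-style tool schemas; both names are assumptions for illustration.
#
#     tools = meta.tools  # meta/info.json["tools"] if declared, else DEFAULT_TOOLS
#     ids = tokenizer.apply_chat_template(messages, tools=tools, tokenize=True)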
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
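# Hedged sketch (not part of the original diff): style-to-column routing with
# the registry constants defined above.
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
assert column_for_style("vqa") == LANGUAGE_EVENTS
assert column_for_style(None) == LANGUAGE_EVENTS  # the implicit style routes to events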
@@ -0,0 +1,593 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
||||
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
)
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. ``events`` is accepted for resolver-signature
|
||||
uniformity but is not consulted: only persistent styles are valid here.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = _matching_rows(
|
||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
matches = [row for row in matches if _timestamp(row) <= t]
|
||||
return _select_latest(
|
||||
matches, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
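# Hedged sketch (not part of the original diff): the most recent subtask row
# whose own timestamp is <= t wins. Timestamps and contents are made up.
_example_rows = [
    {"role": "assistant", "content": "pick up the cube", "style": "subtask", "timestamp": 0.0},
    {"role": "assistant", "content": "place it in the bin", "style": "subtask", "timestamp": 5.0},
]
assert active_at(3.0, persistent=_example_rows, style="subtask")["content"] == "pick up the cube"
assert active_at(6.0, persistent=_example_rows, style="subtask")["content"] == "place it in the bin"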
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
equals ``t``. For event styles, the ``events`` list is assumed to come from
|
||||
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
|
||||
so all matching event rows are considered emitted at ``t``. ``camera``
|
||||
filters by the row's ``camera`` field — required to disambiguate when
|
||||
multiple view-dependent rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
column = column_for_style(style)
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(
|
||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
if _timestamp(row) == t
|
||||
]
|
||||
return _select_one(
|
||||
matches,
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
sort_key=_persistent_sort_key,
|
||||
)
|
||||
matches = _matching_rows(
|
||||
events, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
return _select_one(
|
||||
matches,
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
sort_key=_event_sort_key,
|
||||
)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
style=style,
|
||||
offset=-offset,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
resolver_name="nth_prev",
|
||||
)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
camera: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``/``camera``) and returns the
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
style=style,
|
||||
offset=offset,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
resolver_name="nth_next",
|
||||
)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
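# Hedged sketch (not part of the original diff): the per-sample draw is a pure
# function of ``sample_idx``, so a given sample always lands on the same blend
# component across runs, epochs, and dataloader workers.
def _example_draw(sample_idx: int, total_weight: float) -> float:
    digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") / 2**64 * total_weight

assert _example_draw(7, 1.0) == _example_draw(7, 1.0)  # deterministic
assert 0.0 <= _example_draw(7, 1.0) < 1.0  # always inside the cumulative-weight range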
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(
|
||||
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
|
||||
),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow] = (),
|
||||
sample_idx: int = 0,
|
||||
) -> str | None:
|
||||
"""Return the task string for ``sample_idx``.
|
||||
|
||||
Resolution order:
|
||||
|
||||
1. Explicit ``task`` override (caller-supplied) wins.
|
||||
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||
deterministically pick one by ``sample_idx`` so each frame of an
|
||||
episode rotates through the available rephrasings across an epoch.
|
||||
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||
and falls back to the canonical task otherwise. Recipes that want
|
||||
the literal canonical task can override the binding.
|
||||
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||
backed by ``meta/tasks.parquet``).
|
||||
"""
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [
|
||||
r
|
||||
for r in persistent
|
||||
if r.get("style") == "task_aug" and r.get("role") == "user"
|
||||
]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(
|
||||
f"task_aug:{sample_idx}".encode(), digest_size=8
|
||||
).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
return str(chosen)
|
||||
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
"""Parse a single binding's resolver expression and dispatch to its function."""
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
resolvers = {
|
||||
"active_at": active_at,
|
||||
"emitted_at": emitted_at,
|
||||
"nth_prev": nth_prev,
|
||||
"nth_next": nth_next,
|
||||
}
|
||||
if name not in resolvers:
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
"""Parse a comma-separated resolver argument list into a kwargs dict."""
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
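# Hedged sketch (not part of the original diff): a binding spec as a recipe
# might write it, parsed into resolver kwargs. The spec string itself is an
# illustrative example, not taken from a shipped recipe.
_example_match = _RESOLVER_RE.match("nth_prev(style=subtask, offset=2, role=assistant)")
assert _example_match is not None and _example_match.group("name") == "nth_prev"
assert _parse_resolver_args(_example_match.group("args")) == {
    "style": "subtask",
    "offset": 2,
    "role": "assistant",
}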
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
"""Expand ``recipe.messages`` into rendered chat messages using ``bindings``."""
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Substitute bindings into a string or each string field of multimodal blocks."""
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
"""Replace ``${name}`` placeholders in ``template`` with their bound values."""
|
||||
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
"""Resolve a single ``${name}`` match to its bound string value."""
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return _PLACEHOLDER_RE.sub(replace, template)
|
||||
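# Hedged sketch (not part of the original diff): binding names and contents are
# made up. A dict-valued binding contributes its ``content`` field; a ``None``
# value renders as the empty string, while an unknown name raises.
_example_bindings: dict[str, LanguageRow | str | None] = {
    "task": "stack the blocks",
    "current_subtask": {"role": "assistant", "content": "grasp the red block", "style": "subtask", "timestamp": 1.0},
}
assert _substitute("Task: ${task}. Now: ${current_subtask}.", _example_bindings) == (
    "Task: stack the blocks. Now: grasp the red block."
)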
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
"""Sanity-check the rendered output for stream/target alignment."""
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
for idx, stream in enumerate(streams):
|
||||
if stream is None:
|
||||
raise ValueError(f"Rendered message {idx} has no stream.")
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
resolver_name: str,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(resolver_name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{resolver_name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_persistent_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{resolver_name} requires a persistent style.")
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
column_for_style(style)
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
"""Return ``rows`` filtered by optional ``style``/``role``/``tool_name``/``camera`` selectors."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
and (camera is None or row.get("camera") == camera)
|
||||
]
|
||||
|
||||
|
||||
def _select_latest(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
|
||||
if not rows:
|
||||
return None
|
||||
rows = sorted(rows, key=_persistent_sort_key)
|
||||
latest_ts = _timestamp(rows[-1])
|
||||
return _select_one(
|
||||
[row for row in rows if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
sort_key=_persistent_sort_key,
|
||||
)
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
sort_key: Any,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the selectors are ambiguous."""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1 and role is None and tool_name is None and camera is None:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
|
||||
f"or camera=... to disambiguate."
|
||||
)
|
||||
return sorted(rows, key=sort_key)[0]
|
||||
|
||||
|
||||
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Sort key for persistent rows: ``(timestamp, style, role)``."""
|
||||
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
|
||||
"""Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame)."""
|
||||
return (row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
value = row["timestamp"]
|
||||
return float(value.item() if hasattr(value, "item") else value)
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
"""Return ``True`` if any of the row's tool calls invokes ``tool_name``."""
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
"""Convert pyarrow scalars / mappings into a fresh list of plain dict rows."""
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -83,7 +83,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -238,17 +237,24 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
msg = (
|
||||
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
|
||||
f"either doesn't exist on the Hub yet, or it was uploaded "
|
||||
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
|
||||
f" from huggingface_hub import HfApi\n"
|
||||
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
|
||||
)
|
||||
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
|
||||
# ``__init__`` indexes ``response.headers`` unconditionally on
|
||||
# current ``huggingface_hub`` versions. Constructing it without
|
||||
# a real ``Response`` object crashes with either
|
||||
# ``TypeError: missing 1 required keyword-only argument`` (old
|
||||
# builds) or ``AttributeError: 'NoneType' object has no attribute
|
||||
# 'headers'`` (new builds). Skip that path entirely — this isn't
|
||||
# really an HTTP error, it's a configuration issue — and raise a
|
||||
# plain ``RuntimeError`` so the message actually reaches the
|
||||
# caller.
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
@@ -26,6 +26,7 @@ from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config as SmolVLA2Config
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
@@ -49,6 +50,7 @@ __all__ = [
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"SmolVLA2Config",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"WallXConfig",
|
||||
|
||||
@@ -140,6 +140,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "smolvla2":
|
||||
from .smolvla2.modeling_smolvla2 import SmolVLA2Policy
|
||||
|
||||
return SmolVLA2Policy
|
||||
elif name == "sarm":
|
||||
from .sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
@@ -200,6 +204,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "smolvla2":
|
||||
from .smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
return SmolVLA2Config(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
@@ -386,6 +394,17 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "smolvla2":
|
||||
# NOTE: SmolVLA2Config subclasses SmolVLAConfig, so this branch
|
||||
# MUST come before the SmolVLAConfig isinstance check below
|
||||
# (otherwise SmolVLA2 would silently pick up SmolVLA's processor).
|
||||
from .smolvla2.processor_smolvla2 import make_smolvla2_pre_post_processors
|
||||
|
||||
processors = make_smolvla2_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 — SmolVLA with the SmolVLM language head re-enabled.
|
||||
|
||||
SmolVLA strips the LM head from the SmolVLM backbone because it only does
|
||||
flow-matching action prediction. SmolVLA2 keeps the LM head so the same
|
||||
model can train on the full Hi Robot / MEM / ECoT message blend defined in
|
||||
the steerable annotation plan (PR1 + PR2):
|
||||
|
||||
* action-only sub-recipes (e.g. ``low_level_execution``) → flow loss
|
||||
* text-only sub-recipes (e.g. ``memory_update``, ``ask_vqa``,
|
||||
``user_interjection_response``, ``high_level_subtask``) → CE loss on
|
||||
``lm_head`` over the recipe's target message tokens
|
||||
* mixed sub-recipes → both losses summed (weighted)
|
||||
|
||||
The ``predict_actions`` toggle follows the Pi0.5 convention from Section
|
||||
I.7 of the plan: ``True`` if any ``low_level`` target is present in the
|
||||
sample, else ``False``.
|
||||
|
||||
This package is a thin subclass of ``lerobot.policies.smolvla`` so most of
|
||||
the model code stays in one place — only the dual-loss path and the
|
||||
chat-template processor live here.
|
||||
"""
|
||||
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
__all__ = ["SmolVLA2Config"]
|
||||
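# Hedged sketch (not part of the original diff): one way the two losses described
# above could be combined. The function and weight names are illustrative
# assumptions, not SmolVLA2's actual API; the real combination lives in the
# modeling forward pass.
import torch


def _example_combine_losses(
    flow_loss: torch.Tensor,  # per-sample flow-matching loss
    ce_loss: torch.Tensor,  # per-sample CE over labelled target tokens (0 when every label is -100)
    predict_actions: torch.Tensor,  # bool per sample, True iff a low_level target is present
    text_weight: float = 1.0,
) -> torch.Tensor:
    # Only samples with a low_level target contribute flow loss; CE always applies.
    gated_flow = torch.where(predict_actions, flow_loss, torch.zeros_like(flow_loss))
    return (gated_flow + text_weight * ce_loss).mean()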
@@ -0,0 +1,448 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2's chat-template tokenization step.
|
||||
|
||||
Replaces SmolVLA's plain ``TokenizerProcessorStep`` for SmolVLA2 when a
|
||||
``recipe_path`` is set. Reads the rendered messages produced by
|
||||
``RenderMessagesStep`` (PR 1) and produces:
|
||||
|
||||
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` —
|
||||
the chat-templated prompt tokenized by SmolVLM's tokenizer, with
|
||||
``tools=meta.tools`` (PR 1's catalog).
|
||||
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||
the positions belonging to messages whose index is in
|
||||
``target_message_indices``. The next commit's modeling forward path
|
||||
applies cross-entropy on those positions via the SmolVLM ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``. The
|
||||
modeling forward uses this to gate the flow head.
|
||||
|
||||
Image / video content blocks in the rendered messages are dropped
|
||||
before tokenization — the chat template only handles text, and SmolVLA
|
||||
already passes camera tensors out-of-band via the standard
|
||||
``OBS_IMAGES_*`` features. This keeps the prefix layout unchanged
|
||||
(``embed_prefix`` puts image embeddings before language embeddings,
|
||||
matching the chat-template-stripped text order).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
|
||||
class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
This is the bridge between the recipe stack (PR 1's
|
||||
``RenderMessagesStep`` outputs) and the SmolVLA2 modeling forward
|
||||
(next commit, which reads ``text_labels`` / ``predict_actions``).
|
||||
Pure-text turns and multi-stream targets are both handled.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
|
||||
max_length: int = 2048
|
||||
padding: str = "longest"
|
||||
padding_side: str = "right"
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
# --- Per-component prompt dropout (Pi0.7 §V.E, plan follow-up
|
||||
# ``feat/pi05-prompt-dropout``). At training, drop non-target
|
||||
# messages whose content was substituted from the named recipe
|
||||
# binding with the given probability. Forces the model to handle
|
||||
# missing context at inference — directly attacks the memorisation
|
||||
# collapse where ``current_subtask=""`` puts the prompt OOD. All
|
||||
# default to 0.0 (no dropout) so behaviour is identical until
|
||||
# explicitly opted in via the training config.
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
interjection_dropout_prob: float = 0.0
|
||||
# Optional seed for the per-sample RNG. ``None`` ⇒ use
|
||||
# ``sample_idx`` derived from the transition (when present), so
|
||||
# dropout is reproducible across runs but varies per sample.
|
||||
dropout_seed: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Lazy: don't load the tokenizer until the step actually runs,
|
||||
# so unit tests that import the module without transformers
|
||||
# installed still pass.
|
||||
self._tokenizer: Any = None
|
||||
if self.tools is None:
|
||||
# Default: ship the canonical ``say`` schema. Users who set
|
||||
# ``meta.tools`` differently can override via
|
||||
# ``with_tools(meta.tools)``.
|
||||
self.tools = list(DEFAULT_TOOLS)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def with_tools(self, tools: list[dict[str, Any]]) -> "SmolVLA2ChatTokenizerStep":
|
||||
"""Override the tools catalog rendered into the system prompt."""
|
||||
self.tools = list(tools)
|
||||
return self
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
messages = comp.get("messages")
|
||||
if not messages:
|
||||
# No recipe rendering happened — nothing to do; downstream
|
||||
# falls back to whatever ``task`` is in the transition.
|
||||
return transition
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
|
||||
# Pull a sample_idx for the dropout RNG. ``index`` is the
|
||||
# canonical per-frame key on ``LeRobotDataset`` samples and
|
||||
# flows through into ``COMPLEMENTARY_DATA`` unchanged. When
|
||||
# absent (e.g. inference) we fall back to 0 which is harmless
|
||||
# because the dropout probs are also 0 at inference time.
|
||||
if _is_batched_messages(messages):
|
||||
indices_iter = _sample_indices(comp.get("index"), len(messages))
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
msg,
|
||||
list(streams),
|
||||
sorted(int(i) for i in tgt_indices),
|
||||
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||
)
|
||||
for msg, streams, tgt_indices, s_idx in zip(
|
||||
messages,
|
||||
comp.get("message_streams") or [[] for _ in messages],
|
||||
comp.get("target_message_indices") or [[] for _ in messages],
|
||||
indices_iter,
|
||||
strict=False,
|
||||
)
|
||||
]
|
||||
else:
|
||||
sample_idx = _sample_indices(comp.get("index"), 1)[0]
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
messages,
|
||||
list(comp.get("message_streams") or []),
|
||||
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
]
|
||||
|
||||
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
target_length = self.max_length if self.padding == "max_length" else max(
|
||||
len(ids) for ids, _, _ in encoded
|
||||
)
|
||||
target_length = min(target_length, self.max_length)
|
||||
|
||||
ids_batch = []
|
||||
attn_batch = []
|
||||
labels_batch = []
|
||||
predict_actions = []
|
||||
for ids, labels, predict_action in encoded:
|
||||
ids = ids[:target_length]
|
||||
labels = labels[:target_length]
|
||||
attn = [1] * len(ids)
|
||||
if len(ids) < target_length:
|
||||
n_pad = target_length - len(ids)
|
||||
ids = ids + [pad_id] * n_pad
|
||||
labels = labels + [-100] * n_pad
|
||||
attn = attn + [0] * n_pad
|
||||
ids_batch.append(ids)
|
||||
attn_batch.append(attn)
|
||||
labels_batch.append(labels)
|
||||
predict_actions.append(predict_action)
|
||||
|
||||
ids_t = torch.tensor(ids_batch, dtype=torch.long)
|
||||
attn_t = torch.tensor(attn_batch, dtype=torch.bool)
|
||||
labels_t = torch.tensor(labels_batch, dtype=torch.long)
|
||||
predict_actions_t = torch.tensor(predict_actions, dtype=torch.bool)
|
||||
|
||||
if not _is_batched_messages(messages):
|
||||
ids_t = ids_t.squeeze(0)
|
||||
attn_t = attn_t.squeeze(0)
|
||||
labels_t = labels_t.squeeze(0)
|
||||
predict_actions_t = predict_actions_t.squeeze(0)
|
||||
|
||||
new_complementary = dict(comp)
|
||||
# Drop the per-recipe sidecar keys; everything downstream needs
|
||||
# is now in the tokenized form.
|
||||
new_complementary.pop("messages", None)
|
||||
new_complementary.pop("message_streams", None)
|
||||
new_complementary.pop("target_message_indices", None)
|
||||
# SmolVLA's pipeline expects ``OBS_LANGUAGE_TOKENS`` /
|
||||
# ``OBS_LANGUAGE_ATTENTION_MASK`` on the OBSERVATION key. Place
|
||||
# them there — and drop ``task`` so the upstream
|
||||
# ``TokenizerProcessorStep`` (which we replace) doesn't double-
|
||||
# tokenize.
|
||||
observation = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
observation[OBS_LANGUAGE_TOKENS] = ids_t
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
|
||||
new_complementary["text_labels"] = labels_t
|
||||
new_complementary["predict_actions"] = predict_actions_t
|
||||
new_complementary.pop("task", None)
|
||||
|
||||
new_transition = dict(transition)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
def _encode_messages(
|
||||
self,
|
||||
tokenizer: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
message_streams: list[str | None],
|
||||
target_indices: list[int],
|
||||
sample_idx: int | None = None,
|
||||
) -> tuple[list[int], list[int], bool]:
|
||||
# Apply per-component prompt dropout *before* tokenisation, so
|
||||
# the dropped messages don't contribute tokens or label-mask
|
||||
# positions at all. Re-maps ``target_indices`` to account for
|
||||
# removed messages.
|
||||
messages, target_indices = self._apply_prompt_dropout(
|
||||
messages, target_indices, sample_idx
|
||||
)
|
||||
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
||||
|
||||
full_ids = tokenizer.apply_chat_template(
|
||||
text_messages,
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
full_ids = _as_token_ids(full_ids)
|
||||
|
||||
labels = [-100] * len(full_ids)
|
||||
for tgt in target_indices:
|
||||
prefix_ids = tokenizer.apply_chat_template(
|
||||
text_messages[:tgt],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
full_through_target = tokenizer.apply_chat_template(
|
||||
text_messages[: tgt + 1],
|
||||
tools=self.tools,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors=None,
|
||||
)
|
||||
prefix_ids = _as_token_ids(prefix_ids)
|
||||
full_through_target = _as_token_ids(full_through_target)
|
||||
start = len(prefix_ids)
|
||||
end = min(len(full_through_target), len(full_ids))
|
||||
for pos in range(start, end):
|
||||
labels[pos] = int(full_ids[pos])
|
||||
|
||||
predict_actions = any(
|
||||
i < len(message_streams) and message_streams[i] == "low_level"
|
||||
for i in target_indices
|
||||
)
|
||||
return [int(i) for i in full_ids], labels, predict_actions
|
||||
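# Hedged illustration (not part of the original diff): with an assumed
# tokenisation where messages[:2] render to 40 tokens and messages[:3] render
# to 52, a target at index 2 sets labels[40:52] to the corresponding token ids
# and leaves every other position at -100, so cross-entropy only supervises
# that target turn.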
|
||||
def _apply_prompt_dropout(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
sample_idx: int | None,
|
||||
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||
"""Probabilistically drop non-target context messages.
|
||||
|
||||
Heuristic content sniffing — matches the prefix strings that
|
||||
``smolvla2_hirobot.yaml``'s recipes use when injecting plan /
|
||||
memory / subtask / interjection content. Anything else is
|
||||
kept unchanged. Target messages are never dropped (we still
|
||||
need their tokens for supervision).
|
||||
|
||||
Returns ``(new_messages, new_target_indices)`` where the
|
||||
indices are re-mapped to point at the same target turns in
|
||||
the trimmed list.
|
||||
"""
|
||||
probs = {
|
||||
"plan": float(self.plan_dropout_prob or 0.0),
|
||||
"memory": float(self.memory_dropout_prob or 0.0),
|
||||
"subtask": float(self.subtask_dropout_prob or 0.0),
|
||||
"interjection": float(self.interjection_dropout_prob or 0.0),
|
||||
}
|
||||
if not any(p > 0.0 for p in probs.values()):
|
||||
return messages, target_indices
|
||||
|
||||
# Deterministic per-sample RNG so dropout is reproducible
|
||||
# across runs (matters for debugging / repro) but varies
|
||||
# frame-to-frame.
|
||||
import random # noqa: PLC0415
|
||||
|
||||
seed_int = self.dropout_seed if self.dropout_seed is not None else (sample_idx or 0)
|
||||
rng = random.Random(int(seed_int) & 0xFFFFFFFF)
|
||||
|
||||
target_set = set(target_indices)
|
||||
keep_flags: list[bool] = []
|
||||
for i, msg in enumerate(messages):
|
||||
if i in target_set:
|
||||
keep_flags.append(True)
|
||||
continue
|
||||
kind = _classify_message_for_dropout(msg)
|
||||
if kind and rng.random() < probs.get(kind, 0.0):
|
||||
keep_flags.append(False)
|
||||
else:
|
||||
keep_flags.append(True)
|
||||
|
||||
new_messages = [m for m, keep in zip(messages, keep_flags) if keep]
|
||||
# Re-map target_indices: each old index drops by the count of
|
||||
# falsy flags before it.
|
||||
new_target_indices: list[int] = []
|
||||
for old_idx in target_indices:
|
||||
dropped_before = sum(1 for k in keep_flags[:old_idx] if not k)
|
||||
new_target_indices.append(old_idx - dropped_before)
|
||||
return new_messages, sorted(new_target_indices)
|
||||
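# Hedged illustration (not part of the original diff): with messages
# [system, memory, user, assistant], the memory turn (index 1) dropped, and a
# target at old index 3, keep_flags == [True, False, True, True] and the
# target is re-mapped to index 3 - 1 == 2 in the trimmed list.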
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass-through; this step writes runtime tensors not features."""
|
||||
return features
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_tokenizer(self): # noqa: ANN202
|
||||
if self._tokenizer is not None:
|
||||
return self._tokenizer
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError(
|
||||
"SmolVLA2ChatTokenizerStep requires transformers. "
|
||||
"`pip install lerobot[transformers-dep]`."
|
||||
) from exc
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None:
|
||||
self._tokenizer.pad_token = self._tokenizer.eos_token
|
||||
return self._tokenizer
|
||||
|
||||
|
||||
def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Remove LeRobot-specific multimodal blocks from ``message`` content.
|
||||
|
||||
The recipe DSL allows authors to write multimodal content like
|
||||
``{"type": "image", "feature": "observation.images.top"}``. SmolVLM's
|
||||
tokenizer doesn't know that ``feature`` key (it expects ``url`` or
|
||||
``path``). The actual image tensor flows through SmolVLA's
|
||||
``OBS_IMAGES_*`` channels separately; the chat template only needs
|
||||
the text. So we strip non-text blocks before tokenizing.
|
||||
"""
|
||||
new = dict(message)
|
||||
content = new.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text_parts.append({"type": "text", "text": str(block.get("text", ""))})
|
||||
new["content"] = text_parts or [{"type": "text", "text": ""}]
|
||||
elif content is None:
|
||||
new["content"] = [{"type": "text", "text": ""}]
|
||||
else:
|
||||
new["content"] = [{"type": "text", "text": str(content)}]
|
||||
if "tool_calls" in new and not new["tool_calls"]:
|
||||
# Drop empty tool_calls — some templates render them as a
|
||||
# spurious empty marker.
|
||||
new.pop("tool_calls")
|
||||
# ``stream`` and ``target`` were recipe metadata; templates don't
|
||||
# know them and may warn or crash.
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
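# Example (sketch): a recipe message such as
#   {"role": "user", "stream": "low_level",
#    "content": [{"type": "image", "feature": "observation.images.top"},
#                {"type": "text", "text": "pick up the cube"}]}
# comes back as
#   {"role": "user", "content": [{"type": "text", "text": "pick up the cube"}]}
# — the image block and the ``stream`` metadata are gone, while the image
# tensor itself still reaches the model through the OBS_IMAGES_* channels.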
|
||||
|
||||
|
||||
def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
|
||||
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
if value is None:
|
||||
return [None] * batch_size
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.numel() == 1:
|
||||
return [int(value.item())] * batch_size
|
||||
values = value.reshape(-1).tolist()
|
||||
return [int(v) for v in values[:batch_size]]
|
||||
if isinstance(value, (list, tuple)):
|
||||
if len(value) == 1:
|
||||
return _sample_indices(value[0], batch_size)
|
||||
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
|
||||
return [int(value)] * batch_size
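# Behaviour sketch: _sample_indices(torch.tensor([7]), 3) -> [7, 7, 7];
# _sample_indices([4, 5, 6, 9], 3) -> [4, 5, 6]; _sample_indices(None, 3)
# -> [None, None, None].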
|
||||
|
||||
|
||||
def _classify_message_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Best-effort classification of which recipe binding contributed
|
||||
to this message, used for per-component dropout.
|
||||
|
||||
The canonical recipe marks its plan/memory/subtask injections with
|
||||
distinctive prefix strings in the rendered content. Matching on
|
||||
those prefixes is brittle if a future recipe author uses
|
||||
different wording — but it's also localised to one place and
|
||||
only affects the dropout fraction (never the actual semantics).
|
||||
Returns ``None`` for messages we don't recognise; those are
|
||||
always kept.
|
||||
"""
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
t = block.get("text")
|
||||
if isinstance(t, str):
|
||||
text_parts.append(t)
|
||||
content = "\n".join(text_parts)
|
||||
if not isinstance(content, str):
|
||||
return None
|
||||
head = content.lstrip().lower()
|
||||
if head.startswith("plan:") or head.startswith("previous plan"):
|
||||
return "plan"
|
||||
if head.startswith("memory:") or head.startswith("previous memory"):
|
||||
return "memory"
|
||||
if head.startswith("current subtask") or head.startswith("completed subtask"):
|
||||
return "subtask"
|
||||
return None
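# Classification sketch: "Plan: pick, then place" -> "plan";
# "Previous memory: the cube is already in the bin" -> "memory";
# "Current subtask: grasp the cube" -> "subtask";
# "pick up the red block" -> None (message is always kept).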
|
||||
|
||||
|
||||
def _as_token_ids(value: Any) -> list[int]:
|
||||
if (isinstance(value, dict) or hasattr(value, "keys")) and "input_ids" in value:
|
||||
value = value["input_ids"]
|
||||
if hasattr(value, "tolist"):
|
||||
value = value.tolist()
|
||||
if isinstance(value, list) and value and isinstance(value[0], list):
|
||||
value = value[0]
|
||||
return [int(i) for i in value]
|
||||
|
||||
|
||||
# Re-export for tests / introspection
|
||||
strip_lerobot_blocks = _strip_lerobot_blocks
|
||||
@@ -0,0 +1,115 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
|
||||
from ..smolvla.configuration_smolvla import SmolVLAConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla2")
|
||||
@dataclass
|
||||
class SmolVLA2Config(SmolVLAConfig):
|
||||
"""SmolVLA2 — SmolVLA with the underlying SmolVLM language head re-enabled.
|
||||
|
||||
SmolVLA strips the LM head from the SmolVLM backbone because it only
|
||||
needs flow-matching action prediction. SmolVLA2 keeps the LM head so the
|
||||
same model can train on:
|
||||
|
||||
* **action-only sub-recipes** (e.g. ``low_level_execution``) — flow loss
|
||||
on the action expert, same as SmolVLA. ``predict_actions=True``.
|
||||
* **text-only sub-recipes** (e.g. ``memory_update`` / ``ask_vqa`` /
|
||||
``user_interjection_response`` / ``high_level_subtask``) — cross-
|
||||
entropy loss on the LM head over the recipe's target message tokens.
|
||||
Skips the flow head entirely. ``predict_actions=False``.
|
||||
* **mixed sub-recipes** — both heads run, losses summed (weighted).
|
||||
|
||||
The split is controlled by ``predict_actions = bool(targets_by_stream
|
||||
.get("low_level"))`` per the Pi0.5 convention in the steerable
|
||||
annotation plan (Section I.7), implemented inside the processor /
|
||||
forward path. Recipes drive it via ``stream`` + ``target`` metadata.
|
||||
|
||||
Compared to ``SmolVLAConfig`` this adds:
|
||||
|
||||
- ``recipe_path``: path to a ``TrainingRecipe`` YAML (loaded by the
|
||||
train script). When ``None``, SmolVLA2 falls back to the SmolVLA
|
||||
task-only path so unannotated datasets still work.
|
||||
- ``text_loss_weight`` / ``flow_loss_weight``: relative weights when
|
||||
both losses are active in a single sample.
|
||||
- ``unfreeze_lm_head``: must be ``True`` for the text head to learn —
|
||||
SmolVLA freezes ``lm_head`` to "avoid unused params issues" and we
|
||||
need to undo that for SmolVLA2.
|
||||
- ``train_expert_only=False`` by default, since the VLM body now also
|
||||
participates in text-target gradients.
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/smolvla2_hirobot.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. The default points at the canonical Hi Robot
|
||||
blend shipped alongside SmolVLA2. Set to ``None`` to disable recipe
|
||||
rendering and fall back to SmolVLA's single-task prompt path
|
||||
(unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = True
|
||||
"""Apply the SmolVLM tokenizer's chat template to the rendered messages
|
||||
before tokenizing. SmolVLM's backbone is chat-pretrained, so this
|
||||
matches its training distribution."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / SmolVLA behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 1.0
|
||||
"""Weight on the action-expert flow-matching term."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to unfreeze the SmolVLM ``lm_head`` (and the immediately
|
||||
preceding norm + last text-model layer that SmolVLA freezes). Must be
|
||||
``True`` for the text head to learn. Setting this to ``False``
|
||||
effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
|
||||
which is occasionally useful for ablations."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# At training, randomly drop non-target context messages whose
|
||||
# content was substituted from the named recipe binding. Forces
|
||||
# the model to handle missing context — directly attacks the
|
||||
# memorisation collapse where a stale or missing plan/memory at
|
||||
# inference puts the prompt out-of-distribution and the LM head
|
||||
# falls back to dominant-mode fragments. All default to 0.0 so
|
||||
# behaviour is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
"""Drop messages whose content starts with ``Plan:`` or ``Previous plan``
|
||||
with this probability per sample."""
|
||||
memory_dropout_prob: float = 0.0
|
||||
"""Drop messages whose content starts with ``Memory:`` or ``Previous memory``
|
||||
with this probability per sample."""
|
||||
subtask_dropout_prob: float = 0.0
|
||||
"""Drop messages whose content starts with ``Current subtask`` or
|
||||
``Completed subtask`` with this probability per sample."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through its text path when the
|
||||
# LM head is producing supervised text. Override the SmolVLA
|
||||
# default (`train_expert_only=True`) unless the user explicitly
|
||||
# opts out of text training via `text_loss_weight=0`.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
# The user can still flip this back via CLI; this only
|
||||
# changes the *default* when SmolVLA2 is actually training a
|
||||
# text head.
|
||||
self.train_expert_only = False
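# Hypothetical usage sketch (assumed flags, following the draccus-style CLI
# used elsewhere in lerobot; the field names are the ones defined above):
#   lerobot-train \
#     --policy.type=smolvla2 \
#     --policy.recipe_path=recipes/smolvla2_hirobot.yaml \
#     --policy.text_loss_weight=1.0 \
#     --policy.plan_dropout_prob=0.1 --policy.memory_dropout_prob=0.1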
|
||||
@@ -0,0 +1,73 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 inference / runtime orchestration.
|
||||
|
||||
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
|
||||
low_level_execution → LowLevelForward + DispatchAction (high Hz)
|
||||
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
|
||||
memory_update → MemoryUpdateFwd (event: subtask_change)
|
||||
user_interjection_response → UserInterjectionFwd (event: stdin)
|
||||
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||
|
||||
The CLI ``lerobot-smolvla2-runtime`` builds an ``SmolVLA2Runtime`` and
|
||||
calls ``run()``.
|
||||
"""
|
||||
|
||||
from .repl import StdinReader
|
||||
from .runtime import SmolVLA2Runtime
|
||||
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||
from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||
|
||||
__all__ = [
|
||||
# runtime
|
||||
"SmolVLA2Runtime",
|
||||
"StdinReader",
|
||||
# state helpers
|
||||
"initial_runtime_state",
|
||||
"push_log",
|
||||
"set_if_changed",
|
||||
"take_event",
|
||||
# triggers
|
||||
"Trigger",
|
||||
"Tick",
|
||||
"TickClock",
|
||||
"HzTrigger",
|
||||
"EventTrigger",
|
||||
# steps
|
||||
"InferenceStep",
|
||||
"LowLevelForward",
|
||||
"DispatchAction",
|
||||
"HighLevelSubtaskFwd",
|
||||
"MemoryUpdateFwd",
|
||||
"UserInterjectionFwd",
|
||||
"AskVQAFwd",
|
||||
"DispatchToolCalls",
|
||||
# UI
|
||||
"make_state_panel",
|
||||
"print_robot_lines",
|
||||
"print_user_line",
|
||||
]
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stdin REPL event collector for the SmolVLA2 runtime.
|
||||
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
"stop" / "quit" / "exit" → state["stop"] = True
|
||||
ends with "?" → user_vqa_query event
|
||||
starts with "task:" or first line → set runtime task
|
||||
anything else → user_interjection event
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class StdinReader:
|
||||
"""Non-blocking stdin line collector for the runtime loop."""
|
||||
|
||||
prompt: str = "> "
|
||||
_seen_first_line: bool = field(default=False, init=False)
|
||||
_prompted: bool = field(default=False, init=False)
|
||||
|
||||
def poll(self, state: dict[str, Any]) -> None:
|
||||
"""Drain pending stdin lines into runtime events."""
|
||||
# Print the input prompt once on every fresh tick if we don't
|
||||
# already have a pending line; matches the expected REPL feel.
|
||||
if not self._prompted:
|
||||
print(self.prompt, end="", flush=True)
|
||||
self._prompted = True
|
||||
|
||||
# ``select`` with timeout=0 makes this non-blocking. Only works
|
||||
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
|
||||
try:
|
||||
ready, _, _ = select.select([sys.stdin], [], [], 0)
|
||||
except (ValueError, OSError):
|
||||
return
|
||||
if not ready:
|
||||
return
|
||||
|
||||
line = sys.stdin.readline()
|
||||
if not line: # EOF
|
||||
state["stop"] = True
|
||||
return
|
||||
line = line.strip()
|
||||
self._prompted = False # we'll re-prompt next tick
|
||||
if not line:
|
||||
return
|
||||
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
state["stop"] = True
|
||||
return
|
||||
|
||||
# First non-control line sets the task if no task is active.
|
||||
if not state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
state["task"] = task
|
||||
print(f"[smolvla2] Task: {task}", flush=True)
|
||||
self._seen_first_line = True
|
||||
return
|
||||
|
||||
# Question → VQA; statement → interjection.
|
||||
if lower.endswith("?"):
|
||||
state["recent_vqa_query"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||
else:
|
||||
state["recent_interjection"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_interjection")
|
||||
@@ -0,0 +1,197 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 runtime loop.
|
||||
|
||||
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||
event collector, drives ticks through :class:`TickClock`, and prints
|
||||
state-change updates to the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .runtime_state import initial_runtime_state, push_log
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import HzTrigger, TickClock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmolVLA2Runtime:
|
||||
"""Compose the inference pipeline and drive it tick-by-tick."""
|
||||
|
||||
policy: Any
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
|
||||
from :func:`lerobot.tools.get_tools(meta)` when wiring the
|
||||
runtime."""
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
"""Closure returning the current preprocessed observation batch.
|
||||
``None`` for dry-run / language-only sessions."""
|
||||
robot_executor: Callable[[Any], None] | None = None
|
||||
"""Closure that takes one action chunk and forwards it to the
|
||||
robot. ``None`` for dry-run."""
|
||||
event_collector: Callable[[dict], None] | None = None
|
||||
"""Per-tick hook that polls external sources (stdin, network) and
|
||||
appends event names to ``state["events_this_tick"]``."""
|
||||
chunk_hz: float = 4.0
|
||||
ctrl_hz: float = 50.0
|
||||
high_level_hz: float = 1.0
|
||||
max_rate_hz: float = 50.0
|
||||
|
||||
pipeline: list[InferenceStep] = field(init=False)
|
||||
state: dict[str, Any] = field(init=False)
|
||||
_stop: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Pipeline order matters. Both ``HighLevelSubtaskFwd`` and
|
||||
# ``LowLevelForward`` are gated on "action queue is empty" so
|
||||
# the slow LLM call (select_message) doesn't starve dispatch.
|
||||
# If LowLevelForward runs first, it refills the queue and the
|
||||
# high-level step never sees ``queue == 0`` afterwards.
|
||||
#
|
||||
# Order is therefore: high-level steps that read state (subtask,
|
||||
# memory, interjection, vqa) → low-level chunk refresh → action
|
||||
# dispatch → tool dispatch. So on an empty-queue tick the
|
||||
# subtask refreshes first, the new subtask string flows into
|
||||
# the next chunk's prompt, and DispatchAction drains.
|
||||
self.pipeline = [
|
||||
HighLevelSubtaskFwd(
|
||||
trigger=HzTrigger(self.high_level_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
MemoryUpdateFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
UserInterjectionFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
AskVQAFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
LowLevelForward(
|
||||
trigger=HzTrigger(self.chunk_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
DispatchAction(
|
||||
trigger=HzTrigger(self.ctrl_hz),
|
||||
robot_executor=self.robot_executor,
|
||||
),
|
||||
DispatchToolCalls(tools=self.tools),
|
||||
]
|
||||
self.state = initial_runtime_state()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_task(self, task: str) -> None:
|
||||
"""Set or replace the active task. Logged for the REPL."""
|
||||
self.state["task"] = task
|
||||
push_log(self.state, f"Task: {task}")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop = True
|
||||
|
||||
def run(self, *, max_ticks: int | None = None) -> None:
|
||||
"""Main loop. Returns when ``stop()`` is called or after
|
||||
``max_ticks`` ticks (useful for tests / dry-run)."""
|
||||
clock = TickClock(max_rate_hz=self.max_rate_hz)
|
||||
while not self._stop:
|
||||
tick = clock.advance()
|
||||
self.state["_tick"] = tick
|
||||
self.state["events_this_tick"] = []
|
||||
self.state["log_lines"] = []
|
||||
|
||||
if self.event_collector is not None:
|
||||
self.event_collector(self.state)
|
||||
if self.state.get("stop"):
|
||||
self._stop = True
|
||||
break
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
self._flush_logs()
|
||||
if max_ticks is not None and tick.index >= max_ticks:
|
||||
break
|
||||
|
||||
self._on_shutdown()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# REPL helper: drive one full pipeline pass and return its logs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def step_once(self) -> list[str]:
|
||||
"""Run one tick of the pipeline and return the log lines.
|
||||
|
||||
Used by the interactive REPL: instead of a background thread,
|
||||
the CLI drives ticks synchronously after each user input. Logs
|
||||
are returned (not printed) so the caller can route them into
|
||||
the rich-Live chat scrollback.
|
||||
"""
|
||||
from .triggers import Tick # noqa: PLC0415
|
||||
|
||||
# Synthesize a tick. We don't need the real wall-clock pacing
|
||||
# here — the REPL drives the runtime, not vice versa — but
|
||||
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
|
||||
# bump it generously so every Hz-triggered step considers
|
||||
# itself due.
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
|
||||
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
|
||||
self.state["log_lines"] = []
|
||||
# ``events_this_tick`` is set up by the caller before
|
||||
# ``step_once`` (the REPL pushes user-driven events first).
|
||||
self.state.setdefault("events_this_tick", [])
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
return list(self.state.get("log_lines") or [])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _flush_logs(self) -> None:
|
||||
for line in self.state.get("log_lines") or []:
|
||||
print(f"[smolvla2] {line}", flush=True)
|
||||
|
||||
def _on_shutdown(self) -> None:
|
||||
# Drain any queued action chunks safely.
|
||||
queue = self.state.get("action_queue")
|
||||
if isinstance(queue, deque):
|
||||
queue.clear()
|
||||
print("[smolvla2] runtime stopped", flush=True)
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Runtime state passed between inference steps each tick.
|
||||
|
||||
The runtime threads a single dict through the pipeline; this module
|
||||
documents the shape and provides factories. We use a plain ``dict``
|
||||
rather than a frozen dataclass because steps freely add and remove
|
||||
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
|
||||
…) and dataclass field churn would just get in the way.
|
||||
|
||||
Stable keys (read by multiple steps):
|
||||
|
||||
task str the current top-level task
|
||||
current_plan str | None latest plan emitted by the planner
|
||||
current_subtask str | None latest subtask the policy is executing
|
||||
current_memory str | None latest compressed memory
|
||||
recent_interjection str | None most recent user interjection text (consumed)
|
||||
|
||||
action_queue collections.deque[Tensor] pending action chunks
|
||||
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
|
||||
|
||||
events_this_tick list[str] triggers consumed this tick
|
||||
_tick Tick current tick (set by the loop)
|
||||
|
||||
log_lines list[str] human-readable status lines printed each tick
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
|
||||
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
|
||||
"""Build a fresh runtime state dict with sensible defaults."""
|
||||
return {
|
||||
"task": task,
|
||||
"current_plan": None,
|
||||
"current_subtask": None,
|
||||
"current_memory": None,
|
||||
"recent_interjection": None,
|
||||
"action_queue": deque(),
|
||||
"tool_calls_pending": [],
|
||||
"events_this_tick": [],
|
||||
"log_lines": [],
|
||||
"stop": False,
|
||||
}
|
||||
|
||||
|
||||
def take_event(state: dict[str, Any], event_name: str) -> bool:
|
||||
"""Pop ``event_name`` from ``events_this_tick`` if present.
|
||||
|
||||
Steps that consume an event call this so the same event doesn't
|
||||
re-fire on a sibling step within the same tick.
|
||||
"""
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
if event_name in events:
|
||||
events.remove(event_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def push_log(state: dict[str, Any], line: str) -> None:
|
||||
"""Append ``line`` to the per-tick log buffer; the runtime prints
|
||||
it at the end of the tick."""
|
||||
state.setdefault("log_lines", []).append(line)
|
||||
|
||||
|
||||
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
|
||||
"""Update ``state[key]`` and log a diff line if the value changed.
|
||||
|
||||
Returns ``True`` if the value actually changed.
|
||||
"""
|
||||
prev = state.get(key)
|
||||
if prev == value:
|
||||
return False
|
||||
state[key] = value
|
||||
if label is not None:
|
||||
push_log(state, f" {label}: {value}")
|
||||
return True
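# Contract sketch for one tick, using only the helpers above:
#   state = initial_runtime_state(task="tidy the desk")
#   state["events_this_tick"] = ["user_interjection"]
#   take_event(state, "user_interjection")   # -> True, event consumed
#   take_event(state, "user_interjection")   # -> False on the second call
#   set_if_changed(state, "current_subtask", "grasp the cube", label="subtask")
#   # -> True; logs a "subtask: grasp the cube" line into state["log_lines"]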
|
||||
@@ -0,0 +1,874 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference steps for the SmolVLA2 multi-rate runtime.
|
||||
|
||||
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
|
||||
the runtime applies them in order each tick. When a step's trigger
|
||||
doesn't fire, the step is a no-op and the runtime moves on.
|
||||
|
||||
Stream-to-step mapping mirrors the ``smolvla2_hirobot.yaml`` recipe:
|
||||
|
||||
* ``LowLevelForward`` — calls ``policy.predict_action_chunk`` for the
action chunk and pushes its steps onto ``action_queue``; trained by
``low_level_execution``
|
||||
* ``DispatchAction`` — pops one action per control tick and
|
||||
forwards to the robot
|
||||
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
|
||||
next subtask; trained by
|
||||
``high_level_subtask``
|
||||
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
|
||||
``memory_update``
|
||||
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
|
||||
``user_interjection_response``
|
||||
* ``AskVQAFwd`` — fires on stdin question; trained by
|
||||
``ask_vqa_*``
|
||||
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
|
||||
the matching ``Tool`` instance
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log, set_if_changed, take_event
|
||||
from .triggers import EventTrigger, HzTrigger, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step base + runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceStep:
|
||||
"""A trigger-gated callable. Subclasses override :meth:`run`."""
|
||||
|
||||
trigger: Trigger
|
||||
|
||||
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.trigger.should_fire(state["_tick"], state):
|
||||
return state
|
||||
return self.run(state) or state
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level (action) path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LowLevelForward(InferenceStep):
|
||||
"""Run the policy's action head and produce one action chunk."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Callable ``() -> dict``: returns the current observation batch
|
||||
(already preprocessed). Typically wraps the robot's camera /
|
||||
proprio reads. ``None`` in dry-run mode → step skips."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or self.observation_provider is None:
|
||||
return None
|
||||
if not state.get("task"):
|
||||
return None
|
||||
|
||||
# SmolVLA produces *action chunks* (typically 50 steps via
|
||||
# flow-matching). Every step gets dispatched to the robot;
|
||||
# popping one per dispatch tick is essentially free. Only
|
||||
# generate a new chunk once the previous one has fully
|
||||
# drained — this is the canonical "sense → think → act"
|
||||
# loop. Refreshing while a chunk is still queued causes the
|
||||
# new chunk to "telescope" past the old one (planned from an
|
||||
# observation that's already 25+ steps stale by the time it
|
||||
# starts dispatching).
|
||||
queue = state.setdefault("action_queue", [])
|
||||
if len(queue) > 0:
|
||||
return None
|
||||
|
||||
observation = self.observation_provider()
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# Same prompt construction as before — task + plan + memory,
|
||||
# optional current subtask — then merge into the obs batch.
|
||||
ctx = _control_context_messages(state)
|
||||
if state.get("current_subtask"):
|
||||
ctx = ctx + [{"role": "assistant", "content": state["current_subtask"]}]
|
||||
text_batch = _build_text_batch(self.policy, ctx)
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
observation = dict(observation)
|
||||
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||
|
||||
try:
|
||||
# ``predict_action_chunk`` returns the *full* chunk shape
|
||||
# ``(batch, n_action_steps, action_dim)``. Enqueue every
|
||||
# step so DispatchAction at ctrl_hz can drain them
|
||||
# smoothly until the next refresh.
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"predict_action_chunk failed: %s",
|
||||
exc,
|
||||
exc_info=logger.isEnabledFor(logging.DEBUG),
|
||||
)
|
||||
push_log(
|
||||
state,
|
||||
f" [warn] predict_action_chunk failed: "
|
||||
f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
return None
|
||||
|
||||
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
|
||||
# each step as a ``(1, action_dim)`` tensor so the existing
|
||||
# action executor's batch-squeeze logic works unchanged.
|
||||
if chunk.ndim == 3:
|
||||
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
|
||||
elif chunk.ndim == 2:
|
||||
chunk_iter = chunk
|
||||
else:
|
||||
chunk_iter = chunk.unsqueeze(0)
|
||||
for step in chunk_iter:
|
||||
queue.append(step.unsqueeze(0))
|
||||
state["last_chunk_size"] = int(chunk_iter.shape[0])
|
||||
return None
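# Shape sketch: with a typical chunk of (batch=1, n_action_steps=50,
# action_dim=d) the loop above enqueues 50 tensors of shape (1, d);
# DispatchAction then drains them at ctrl_hz until the queue is empty and
# the next refresh fires.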
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchAction(InferenceStep):
|
||||
"""Pop one action per tick and hand it to the robot.
|
||||
|
||||
In dry-run mode (``robot_executor=None``) the step still pops the
|
||||
queue so it doesn't grow unbounded — the popped tensor is logged
|
||||
instead of executed.
|
||||
|
||||
Wall-clock catch-up: the action queue represents an open-loop
|
||||
trajectory at a fixed step rate (``trigger.hz`` ≈ ``ctrl_hz``).
|
||||
When the main loop stalls — e.g. an LLM call for the high-level
|
||||
subtask blocks for ~2 s on MPS — the dispatch trigger fires only
|
||||
once over that whole interval. Naively popping a single entry per
|
||||
fire makes the robot lag further and further behind the planned
|
||||
timeline, and a 50-step chunk would take ~125 s to drain instead
|
||||
of ~1.7 s. Track real elapsed time between dispatches and pop
|
||||
``round(elapsed * hz)`` entries, sending the most recent one. The
|
||||
skipped intermediate joint targets are stale anyway — the dynamixel
|
||||
will smooth toward the latest goal position.
|
||||
"""
|
||||
|
||||
robot_executor: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
|
||||
_last_dispatch_t: float | None = field(default=None, init=False)
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
queue = state.get("action_queue")
|
||||
if not queue:
|
||||
# Reset wall-clock anchor when the queue is empty so the
|
||||
# next chunk doesn't see a huge fake "elapsed" window.
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
now = _time.monotonic()
|
||||
hz = getattr(self.trigger, "hz", 30.0)
|
||||
if self._last_dispatch_t is None or hz <= 0:
|
||||
n_to_pop = 1
|
||||
else:
|
||||
elapsed = now - self._last_dispatch_t
|
||||
# ``max(1, ...)`` so we always pop at least one when the
|
||||
# trigger fires; ``min(len(queue), ...)`` so we don't run
|
||||
# off the end of the chunk.
|
||||
n_to_pop = max(1, min(len(queue), int(round(elapsed * hz))))
|
||||
self._last_dispatch_t = now
|
||||
|
||||
# Drain ``n_to_pop`` stale entries, keep only the latest as the
|
||||
# action actually sent. The intermediate joint targets would
|
||||
# all be ~10–30 ms apart in chunk time — the robot can't track
|
||||
# them individually anyway when the host loop is slow.
|
||||
latest = None
|
||||
for _ in range(n_to_pop):
|
||||
if not queue:
|
||||
break
|
||||
latest = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
|
||||
state["actions_dispatched"] = state.get("actions_dispatched", 0) + 1
|
||||
|
||||
if latest is not None and self.robot_executor is not None:
|
||||
self.robot_executor(latest)
|
||||
return None
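# Worked example of the catch-up rule above: ctrl_hz = 50 and the host loop
# stalls for ~2 s on a slow select_message call -> elapsed * hz ≈ 100, so up
# to 100 queued entries (capped at len(queue)) are drained and only the most
# recent target is sent, instead of replaying ~2 s of stale targets one per
# tick.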
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level (text) paths — all use policy.select_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Tokenize a list of chat messages into the batch shape
|
||||
``select_message`` expects.
|
||||
|
||||
Lazy fallback: re-uses the policy's preprocessor by piggy-backing
|
||||
on the chat tokenizer step. Production use should construct the
|
||||
batch from a real observation; here we focus on the *language*
|
||||
path which is independent of camera observations.
|
||||
"""
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(policy.config.vlm_model_name)
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Reuse the *exact* normaliser that the training-time chat
|
||||
# tokenizer step uses (``_strip_lerobot_blocks``). It handles all
|
||||
# the cases the SmolVLM chat template expects:
|
||||
# * ``content: list[block]`` → keep text blocks, drop images
|
||||
# * ``content: None`` → ``[{type: text, text: ""}]``
|
||||
# * ``content: str`` / anything else → ``[{type: text, text: str(content)}]``
|
||||
# Doing it any other way creates a training/inference mismatch in
|
||||
# exactly the prompt shape the model was supervised on. Also
|
||||
# strips ``stream`` / ``target`` recipe metadata.
|
||||
from lerobot.policies.smolvla2.chat_processor_smolvla2 import ( # noqa: PLC0415
|
||||
_strip_lerobot_blocks,
|
||||
)
|
||||
|
||||
text_messages = [_strip_lerobot_blocks(m) for m in prompt_messages]
|
||||
encoded = tokenizer.apply_chat_template(
|
||||
text_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
# ``apply_chat_template`` can return any of:
|
||||
# - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers),
|
||||
# - a list[int] / list[list[int]] (when ``return_tensors`` is ignored),
|
||||
# - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask``
|
||||
# (newer transformers, especially via processor.apply_chat_template
|
||||
# forwarding through here).
|
||||
# Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's
|
||||
# attention mask when available so we don't have to re-derive it
|
||||
# from ``pad_token_id`` (which can be ``None`` for SmolVLM).
|
||||
attn: Any = None
|
||||
if hasattr(encoded, "input_ids"):
|
||||
ids = encoded.input_ids
|
||||
attn = getattr(encoded, "attention_mask", None)
|
||||
elif isinstance(encoded, dict) and "input_ids" in encoded:
|
||||
ids = encoded["input_ids"]
|
||||
attn = encoded.get("attention_mask")
|
||||
else:
|
||||
ids = encoded
|
||||
if isinstance(ids, list):
|
||||
if ids and isinstance(ids[0], list):
|
||||
ids = ids[0]
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
ids = torch.tensor(ids, dtype=torch.long)
|
||||
if hasattr(ids, "ndim") and ids.ndim == 1:
|
||||
ids = ids.unsqueeze(0)
|
||||
if attn is None and tokenizer.pad_token_id is not None:
|
||||
attn = ids != tokenizer.pad_token_id
|
||||
elif isinstance(attn, list):
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
attn = torch.tensor(attn, dtype=torch.long)
|
||||
if attn.ndim == 1:
|
||||
attn = attn.unsqueeze(0)
|
||||
# SmolVLA's ``eager_attention_forward`` does
|
||||
# ``torch.where(attention_mask[..., None, :, :], ...)`` which
|
||||
# requires a *bool* condition tensor; ``BatchEncoding``'s
|
||||
# attention_mask is typically Long (0/1). Cast so the prefix
|
||||
# forward doesn't blow up with ``where expected condition to be a
|
||||
# boolean tensor, but got a tensor with dtype Long``.
|
||||
if attn is not None and hasattr(attn, "dtype"):
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
if attn.dtype != _torch.bool:
|
||||
attn = attn.bool()
|
||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||
# model), which the caller's broad except would swallow silently.
|
||||
device = getattr(getattr(policy, "config", None), "device", None)
|
||||
if device is not None:
|
||||
try:
|
||||
ids = ids.to(device)
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
||||
new = dict(m)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
|
||||
|
||||
|
||||
@dataclass
|
||||
class HighLevelSubtaskFwd(InferenceStep):
|
||||
"""At ~1 Hz, ask the policy for the next subtask.
|
||||
|
||||
Mirrors the ``high_level_subtask`` recipe layout exactly:
|
||||
|
||||
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||
user: "Current subtask: ${subtask}" (if subtask present)
|
||||
↓ generate ↓
|
||||
assistant: <next subtask>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Same shape as ``LowLevelForward.observation_provider``. When
|
||||
set, the resulting observation is merged into ``select_message``'s
|
||||
batch so text generation runs against real video + state."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not state.get("task"):
|
||||
return None
|
||||
# Gate to chunk boundaries: only generate a fresh subtask when
|
||||
# the action queue is empty (i.e. right before LowLevelForward
|
||||
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
|
||||
# and running it every loop iteration starves DispatchAction
|
||||
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
|
||||
# of 30/sec and the robot barely moves. Tying it to the same
|
||||
# "queue empty" condition as the chunk refresh produces a
|
||||
# clean sense → think → act cycle.
|
||||
queue = state.get("action_queue") or []
|
||||
if len(queue) > 0:
|
||||
return None
|
||||
ctx = _msgs_for_subtask(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
# Default: greedy argmax, no min_new_tokens, no special-token
|
||||
# suppression — matches training. Operator can override via
|
||||
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
|
||||
# on the CLI; useful for under-trained checkpoints whose LM
|
||||
# head still favours EOS at position 0 (pre-trained chat
|
||||
# backbone's short-turn prior hasn't been fully overridden
|
||||
# by the fine-tuning supervision yet).
|
||||
msg = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="subtask gen",
|
||||
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
|
||||
temperature=float(state.get("text_gen_temperature") or 0.0),
|
||||
top_p=float(state.get("text_gen_top_p") or 1.0),
|
||||
)
|
||||
# Diagnostics: surface what the model is *actually* producing
|
||||
# at chunk boundaries, even when the output gets rejected or
|
||||
# repeats. Memorisation collapse looks like "same accepted
|
||||
# subtask N times in a row" or "gibberish_count rising while
|
||||
# current_subtask is stuck". The state panel renders these.
|
||||
state["last_subtask_raw"] = msg or ""
|
||||
# Persistent empty completion is its own failure mode (model
|
||||
# immediately EOS-es from the chat-template generation
|
||||
# prompt) — surface it once every N occurrences so the
|
||||
# operator can distinguish "generation failing silently"
|
||||
# from "generating fine but filter rejecting".
|
||||
if not msg:
|
||||
empties = state.get("subtask_empty_count", 0) + 1
|
||||
state["subtask_empty_count"] = empties
|
||||
if empties == 1 or empties % 5 == 0:
|
||||
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
|
||||
if debug:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen empty (×{empties}); {debug}",
|
||||
)
|
||||
else:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen returned empty (×{empties}) — "
|
||||
"no tokens generated (head EOS-ing before any "
|
||||
"non-special token).",
|
||||
)
|
||||
if msg and _looks_like_gibberish(msg):
|
||||
# Bump a counter so the operator can see the model is
|
||||
# struggling without spamming the log every tick. A first
|
||||
# rejection still logs once so the failure is visible.
|
||||
count = state.get("subtask_gibberish_count", 0) + 1
|
||||
state["subtask_gibberish_count"] = count
|
||||
if count == 1 or count % 30 == 0:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if msg:
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
# Subtask change is a downstream trigger.
|
||||
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||
state["subtask_repeat_count"] = 0
|
||||
else:
|
||||
# Same accepted string regenerated — memorisation tell.
|
||||
# Once this counter climbs past a few, you're seeing
|
||||
# the model unable to move past the current subtask
|
||||
# despite the chunk having drained (visual scene may
|
||||
# have changed but the LM is replaying training
|
||||
# tokens).
|
||||
state["subtask_repeat_count"] = (
|
||||
state.get("subtask_repeat_count", 0) + 1
|
||||
)
|
||||
# Silently skip empty completions — common when the model
|
||||
# warms up or generates only EOS; logging it every tick at
|
||||
# ctrl_hz is just noise.
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryUpdateFwd(InferenceStep):
|
||||
"""On subtask boundary, refresh the compressed memory.
|
||||
|
||||
Mirrors the ``memory_update`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if prior memory)
|
||||
user: "Completed subtask: ${completed_subtask}" (if subtask)
|
||||
↓ generate ↓
|
||||
assistant: <new memory>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Don't consume the event — multiple steps may want to react.
|
||||
if self.policy is None:
|
||||
return None
|
||||
ctx = _msgs_for_memory(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="memory gen",
|
||||
)
|
||||
state["last_memory_raw"] = new_memory or ""
|
||||
if new_memory and _looks_like_gibberish(new_memory):
|
||||
count = state.get("memory_gibberish_count", 0) + 1
|
||||
state["memory_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] memory gen rejected (gibberish ×{count}): {new_memory[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if new_memory:
|
||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInterjectionFwd(InferenceStep):
|
||||
"""On stdin interjection, refresh the plan + emit a paired ``say``.
|
||||
|
||||
Mirrors the ``user_interjection_response`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
|
||||
user: "${interjection}" (the new utterance)
|
||||
↓ generate ↓
|
||||
assistant: <plan + <say>...</say>>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_interjection"):
|
||||
return None
|
||||
ctx = _msgs_for_interjection(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="plan/say gen",
|
||||
)
|
||||
if not out:
|
||||
# Don't log every empty completion — happens repeatedly on
|
||||
# MPS during warm-up and floods the panel. The user can
|
||||
# re-trigger by typing again.
|
||||
return None
|
||||
if _looks_like_gibberish(out):
|
||||
count = state.get("plan_gibberish_count", 0) + 1
|
||||
state["plan_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] plan/say gen rejected (gibberish ×{count}): {out[:60]!r}",
|
||||
)
|
||||
return None
|
||||
# Heuristic split: model is trained to emit one assistant turn
|
||||
# carrying both plan text AND a `say` tool call. Look for a
|
||||
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
||||
# text → plan, no speech.
|
||||
plan_text, speech_text = _split_plan_and_say(out)
|
||||
if plan_text and _looks_like_gibberish(plan_text):
|
||||
plan_text = ""
|
||||
if plan_text:
|
||||
set_if_changed(state, "current_plan", plan_text, label="plan")
|
||||
if speech_text:
|
||||
push_log(state, f" speech: {speech_text}")
|
||||
state.setdefault("tool_calls_pending", []).append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "say", "arguments": {"text": speech_text}},
|
||||
}
|
||||
)
|
||||
state.setdefault("events_this_tick", []).append("tool_call_pending")
|
||||
# Mark interjection consumed.
|
||||
state["recent_interjection"] = None
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskVQAFwd(InferenceStep):
|
||||
"""On stdin question, answer a frame-grounded VQA.
|
||||
|
||||
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
|
||||
turn carrying just the VQA question, plus the camera image block
|
||||
in training (we drop the image at inference because the dataset's
|
||||
image preprocessing doesn't match SmolVLM's vision tower input).
|
||||
|
||||
user: <question>
|
||||
↓ generate ↓
|
||||
assistant: <vqa answer>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_vqa_query"):
|
||||
return None
|
||||
question = state.get("recent_vqa_query")
|
||||
if not question:
|
||||
return None
|
||||
ctx = _msgs_for_vqa(question)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
)
|
||||
# VQA answers are intentionally JSON-like during training, so
|
||||
# ``_looks_like_gibberish`` would false-positive on them. Keep
|
||||
# the answer as-is — the VQA panel line lets the user judge.
|
||||
if answer:
|
||||
push_log(state, f" vqa: {answer}")
|
||||
state["recent_vqa_query"] = None
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchToolCalls(InferenceStep):
|
||||
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
|
||||
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
take_event(state, "tool_call_pending")
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
for call in pending:
|
||||
try:
|
||||
fn = (call or {}).get("function") or {}
|
||||
name = fn.get("name")
|
||||
args = fn.get("arguments") or {}
|
||||
tool = self.tools.get(name)
|
||||
if tool is None:
|
||||
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
|
||||
continue
|
||||
tool.call(args)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
push_log(state, f" [error] tool dispatch failed: {exc}")
|
||||
state["tool_calls_pending"] = []
|
||||
return None
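# Sketch of the pending-call shape this step consumes (the same dict that
# UserInterjectionFwd enqueues above); the tool name must exist in the
# ``tools`` mapping handed to the runtime, otherwise the call is skipped
# with a warning:
#   {"type": "function",
#    "function": {"name": "say", "arguments": {"text": "Starting the next subtask."}}}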
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _looks_like_gibberish(text: str) -> bool:
|
||||
"""Heuristically detect generation that's clearly off the rails.
|
||||
|
||||
Memorised models can collapse to dominant-mode outputs when the
|
||||
prompt drifts even slightly from training distribution. Reject:
|
||||
|
||||
* empty / whitespace-only
|
||||
* too few alphabetic characters (mostly punctuation)
|
||||
* a single character repeated past the threshold
|
||||
* starts with ``":"`` and contains no letters
|
||||
* too few unique tokens — e.g. ``"the"``, ``"the the the"``,
|
||||
``"Ass\\n::\\nthe"`` (the collapse seen on real-robot frames
|
||||
where the model emits one or two memorised tokens repeatedly)
|
||||
* chat-template fragment leakage (``Assistant:``, ``User:``,
|
||||
``Ass\\n``)
|
||||
|
||||
Real subtasks look like ``"close the gripper to grasp the blue
|
||||
cube"`` — multiple unique alphabetic tokens, no role-marker
|
||||
fragments. Anything materially shorter than that is rejected.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
stripped = text.strip()
|
||||
alpha = sum(1 for c in stripped if c.isalpha())
|
||||
if alpha < max(3, len(stripped) // 8):
|
||||
return True
|
||||
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
|
||||
return True
|
||||
# Single repeating char: e.g. ``""""""``.
|
||||
if len(set(stripped)) <= 2 and len(stripped) > 4:
|
||||
return True
|
||||
# Chat-template fragment leakage — the model emits ``Ass``,
|
||||
# ``Assistant:``, ``User:``, often with extra newlines/colons.
|
||||
# Reject if the cleaned text is mostly role-marker shards.
|
||||
cleaned = stripped.replace("\n", " ").replace(":", " ")
|
||||
for marker in ("Assistant", "User", "Ass "):
|
||||
if marker in cleaned and len(cleaned.split()) < 4:
|
||||
return True
|
||||
# Too few unique alphabetic tokens — model stuck on ``the`` or
|
||||
# similar memorised single-token continuations.
|
||||
tokens = [t for t in cleaned.split() if any(c.isalpha() for c in t)]
|
||||
unique_alpha = {t.lower() for t in tokens}
|
||||
if len(unique_alpha) < 3 and len(stripped) < 80:
|
||||
return True
|
||||
return False
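# Filter behaviour on representative strings (sketch):
#   "close the gripper to grasp the blue cube"  -> kept (many unique words)
#   ""                                          -> rejected (empty)
#   '""""""'                                    -> rejected (no letters)
#   "the the the"                               -> rejected (<3 unique tokens)
#   "Assistant:\n:"                             -> rejected (role-marker shard)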
|
||||
|
||||
|
||||
def _control_context_messages(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
include_completed: bool = False,
|
||||
extra_user: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a chat-template-ready prompt from current runtime state.
|
||||
|
||||
Mirrors what ``smolvla2_hirobot.yaml`` renders into ``${task}\nPlan:
|
||||
${plan}\nMemory: ${memory}`` for the high-level branches.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
task = state.get("task") or ""
|
||||
parts.append(task)
|
||||
if state.get("current_plan"):
|
||||
parts.append(f"Plan: {state['current_plan']}")
|
||||
if state.get("current_memory"):
|
||||
parts.append(f"Memory: {state['current_memory']}")
|
||||
if include_completed and state.get("current_subtask"):
|
||||
parts.append(f"Completed subtask: {state['current_subtask']}")
|
||||
head = "\n".join(parts)
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
|
||||
if extra_user:
|
||||
msgs.append({"role": "user", "content": extra_user})
|
||||
return msgs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
|
||||
# message layout in ``smolvla2_hirobot.yaml`` so the chat-templated
|
||||
# prompt at inference matches what the model saw during training.
|
||||
# Generic ``_control_context_messages`` is kept around as a fallback
|
||||
# for ad-hoc callers but the four high-level steps now use these.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``high_level_subtask`` recipe layout (v2 — predict current subtask).
|
||||
|
||||
The training-time recipe was changed to supervise the model on the
|
||||
*current* active subtask span at every frame, not the next-span text
|
||||
only at transitions. So the inference-time prompt no longer feeds a
|
||||
"Current subtask: X" user message — that would be circular (we'd be
|
||||
telling the model the answer). The model now decides the subtask
|
||||
purely from the task + plan + memory context plus the visual prefix.
|
||||
|
||||
Transition detection moves into the runtime: when the predicted
|
||||
subtask differs from ``state['current_subtask']``, fire the
|
||||
``subtask_change`` event so memory updates. Same downstream signal
|
||||
as before, just produced by an always-non-empty supervision target.
|
||||
"""
|
||||
head_parts = [state.get("task") or ""]
|
||||
if state.get("current_plan"):
|
||||
head_parts.append(f"Plan: {state['current_plan']}")
|
||||
if state.get("current_memory"):
|
||||
head_parts.append(f"Memory: {state['current_memory']}")
|
||||
return [{"role": "user", "content": "\n".join(head_parts)}]
|
||||
|
||||
|
||||
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``memory_update`` recipe layout."""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""}
|
||||
]
|
||||
if state.get("current_memory"):
|
||||
msgs.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Previous memory: {state['current_memory']}",
|
||||
}
|
||||
)
|
||||
if state.get("current_subtask"):
|
||||
msgs.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Completed subtask: {state['current_subtask']}",
|
||||
}
|
||||
)
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``user_interjection_response`` recipe layout."""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""}
|
||||
]
|
||||
if state.get("current_plan"):
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
|
||||
)
|
||||
interjection = state.get("recent_interjection")
|
||||
if interjection:
|
||||
msgs.append({"role": "user", "content": interjection})
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
|
||||
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
|
||||
return [{"role": "user", "content": question}]
|
||||
|
||||
|
||||
def _maybe_observation(provider: Any) -> dict | None:
|
||||
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||
|
||||
Errors from the provider are logged at debug level and swallowed —
|
||||
text generation still runs (in text-only mode) so a flaky frame
|
||||
source doesn't kill the REPL.
|
||||
"""
|
||||
if provider is None:
|
||||
return None
|
||||
try:
|
||||
return provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s — falling back to text-only", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_with_policy(
|
||||
policy: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
observation: dict | None = None,
|
||||
state: dict[str, Any] | None = None,
|
||||
label: str = "select_message",
|
||||
min_new_tokens: int = 0,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
When ``observation`` carries ``observation.images.*`` and
|
||||
``observation.state``, those are merged into the batch so
|
||||
``select_message`` runs the same VLM prefix the policy was trained
|
||||
on. Without an observation the runtime falls back to a text-only
|
||||
prompt — the text head still runs, but generations may drift from
|
||||
the training distribution.
|
||||
|
||||
Failures are surfaced both to the module logger (``warning``) and,
|
||||
when ``state`` is given, to the runtime's user-visible log via
|
||||
:func:`push_log`, so the REPL no longer "looks dead" when
|
||||
something goes wrong inside generation.
|
||||
"""
|
||||
if not hasattr(policy, "select_message"):
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] policy has no select_message — skipping {label}")
|
||||
return ""
|
||||
text_batch = _build_text_batch(policy, messages)
|
||||
try:
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
batch: dict[str, Any] = {
|
||||
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||
}
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||
batch[k] = v
|
||||
return policy.select_message(
|
||||
batch,
|
||||
tokenizer=text_batch["tokenizer"],
|
||||
min_new_tokens=min_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
|
||||
return ""
|
||||
|
||||
|
||||
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
def _split_plan_and_say(text: str) -> tuple[str, str]:
|
||||
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
|
||||
|
||||
The training-time tool-call serializer wraps ``say(text="…")`` in a
|
||||
deterministic textual marker so prefix-LM-style training learns to
|
||||
emit it. The runtime parses it back here. If no marker is present,
|
||||
the entire text is treated as plan with no speech.
|
||||
"""
|
||||
if not text:
|
||||
return "", ""
|
||||
match = _SAY_RE.search(text)
|
||||
if not match:
|
||||
return text.strip(), ""
|
||||
speech = match.group(1).strip().strip('"').strip("'")
|
||||
plan = (text[: match.start()] + text[match.end() :]).strip()
|
||||
return plan, speech
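
# Illustrative split (not part of the module): the <say> snippet becomes the
# speech string; everything around the marker is kept as the plan.
#
#   _split_plan_and_say('1. grasp sponge <say>"on it!"</say> 2. wipe counter')
#   # -> ("1. grasp sponge  2. wipe counter", "on it!")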
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trigger primitives for SmolVLA2's multi-rate inference runtime.
|
||||
|
||||
Mirrors the plan's Section "Runtime orchestration": each
|
||||
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
|
||||
whether the step fires. Two trigger flavours cover all the cadences
|
||||
the canonical recipe needs:
|
||||
|
||||
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
|
||||
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
|
||||
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
|
||||
memory update; user interjection → plan refresh; user VQA query →
|
||||
vqa answer; pending tool call → dispatcher)
|
||||
|
||||
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
|
||||
The runtime stores the :class:`TickClock` as ``state["_tick"]`` so
|
||||
every step shares a single time source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tick:
|
||||
"""Single tick from :class:`TickClock`. Carries time references the
|
||||
runtime steps consume to gate themselves."""
|
||||
|
||||
index: int
|
||||
"""Monotonic counter — increments by one per tick."""
|
||||
|
||||
monotonic_seconds: float
|
||||
"""``time.monotonic()`` at the start of this tick."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickClock:
|
||||
"""Drives the runtime loop at up to ``max_rate_hz``.
|
||||
|
||||
Sleeps just enough between :meth:`advance` calls to enforce the
|
||||
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
|
||||
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
|
||||
"""
|
||||
|
||||
max_rate_hz: float = 50.0
|
||||
_index: int = field(default=0, init=False)
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def advance(self) -> Tick:
|
||||
period = 1.0 / max(self.max_rate_hz, 0.1)
|
||||
now = time.monotonic()
|
||||
if self._last_seconds is not None:
|
||||
sleep_for = (self._last_seconds + period) - now
|
||||
if sleep_for > 0:
|
||||
time.sleep(sleep_for)
|
||||
now = time.monotonic()
|
||||
self._last_seconds = now
|
||||
self._index += 1
|
||||
return Tick(index=self._index, monotonic_seconds=now)
|
||||
|
||||
|
||||
class Trigger(Protocol):
|
||||
"""Decide whether the next ``InferenceStep`` should fire."""
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class HzTrigger:
|
||||
"""Fire at most ``hz`` times per second."""
|
||||
|
||||
hz: float
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
period = 1.0 / max(self.hz, 1e-6)
|
||||
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
|
||||
self._last_seconds = tick.monotonic_seconds
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventTrigger:
|
||||
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
|
||||
|
||||
The runtime fills ``events_this_tick`` once per tick from:
|
||||
|
||||
* stdin / network input (``user_interjection``, ``user_vqa_query``,
|
||||
``stop``)
|
||||
* internal state transitions (``subtask_change``,
|
||||
``tool_call_pending``)
|
||||
|
||||
The list is consumed (cleared at the end of the tick) so events
|
||||
fire at most once.
|
||||
"""
|
||||
|
||||
event_name: str
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
return self.event_name in events
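

# Wiring sketch (illustrative; this helper is hypothetical and not exported):
# one clock, a ~1 Hz periodic trigger and an event trigger sharing the same
# state dict. Step bodies are stand-ins.
def _example_wiring() -> None:
    clock = TickClock(max_rate_hz=50.0)
    subtask_beat = HzTrigger(hz=1.0)
    interjection = EventTrigger(event_name="user_interjection")
    state: dict[str, Any] = {"events_this_tick": []}
    for _ in range(100):
        tick = clock.advance()
        if subtask_beat.should_fire(tick, state):
            pass  # ~1 Hz: regenerate the current subtask
        if interjection.should_fire(tick, state):
            pass  # one-shot: refresh the plan from the interjection
        state["events_this_tick"] = []  # events fire at most once per tick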
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rich-based REPL layout for the SmolVLA2 runtime.
|
||||
|
||||
Two-zone terminal layout:
|
||||
|
||||
[chat scrollback — user messages / robot responses, scrolls naturally]
|
||||
|
||||
┌── State ──────────────────────────────────────────┐
|
||||
│ task please clean up the kitchen │
|
||||
│ subtask grasp the handle of the sponge │
|
||||
│ plan 1. grasp sponge 2. wipe 3. tidy │
|
||||
│ memory sponge picked up; counter still dirty │
|
||||
└───────────────────────────────────────────────────┘
|
||||
> _
|
||||
|
||||
The state panel re-renders on every state change. Chat lines are
|
||||
``console.print``'d above the live region so they accumulate naturally
|
||||
in scrollback. Implemented with :class:`rich.live.Live` plus
|
||||
:func:`rich.console.Console.input` for the prompt — when an input is
|
||||
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
|
||||
panel for cursor position.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try: # rich is optional; only required for the interactive REPL.
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
_HAS_RICH = True
|
||||
except ImportError: # pragma: no cover
|
||||
_HAS_RICH = False
|
||||
Console = Any # type: ignore[assignment]
|
||||
Panel = Any # type: ignore[assignment]
|
||||
Table = Any # type: ignore[assignment]
|
||||
Text = Any # type: ignore[assignment]
|
||||
|
||||
|
||||
_STATE_KEYS = (
|
||||
("task", "task"),
|
||||
("current_subtask", "subtask"),
|
||||
("current_plan", "plan"),
|
||||
("current_memory", "memory"),
|
||||
)
|
||||
|
||||
|
||||
def make_state_panel(state: dict[str, Any]) -> Any:
|
||||
"""Render the persistent state panel for the live region.
|
||||
|
||||
Returns a :class:`rich.panel.Panel`. Caller passes it to
|
||||
``Live.update(panel)`` whenever the state changes.
|
||||
"""
|
||||
if not _HAS_RICH:
|
||||
raise RuntimeError(
|
||||
"rich is required for the interactive REPL. "
|
||||
"`pip install rich` (it's a transitive dep of lerobot)."
|
||||
)
|
||||
table = Table.grid(padding=(0, 2), expand=True)
|
||||
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
|
||||
table.add_column(justify="left")
|
||||
for key, label in _STATE_KEYS:
|
||||
value = state.get(key)
|
||||
if value is None:
|
||||
rendered = Text("(not set)", style="dim italic")
|
||||
else:
|
||||
rendered = Text(str(value), style="bold")
|
||||
table.add_row(label, rendered)
|
||||
queue = state.get("action_queue")
|
||||
queue_len = len(queue) if hasattr(queue, "__len__") else 0
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
footer = Text.assemble(
|
||||
("queued actions: ", "dim"),
|
||||
(str(queue_len), "bold cyan"),
|
||||
(" pending tool calls: ", "dim"),
|
||||
(str(len(pending)), "bold magenta"),
|
||||
)
|
||||
table.add_row("", footer)
|
||||
return Panel(table, title="[bold]SmolVLA2 state[/]", border_style="cyan")
|
||||
|
||||
|
||||
def print_user_line(console: Any, line: str) -> None:
|
||||
"""Append a user-typed line to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
print(f"you: {line}", flush=True)
|
||||
return
|
||||
console.print(f"[bold cyan]you:[/] {line}")
|
||||
|
||||
|
||||
def print_robot_lines(console: Any, lines: list[str]) -> None:
|
||||
"""Append robot/runtime log lines to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
for line in lines:
|
||||
print(f"robot: {line.lstrip()}", flush=True)
|
||||
return
|
||||
for line in lines:
|
||||
        # The runtime uses leading whitespace + "label: text"; render the
        # ``robot`` prefix in green, the label dimmed, and the value in the
        # default style for readability.
|
||||
stripped = line.lstrip()
|
||||
if ":" in stripped:
|
||||
label, _, value = stripped.partition(":")
|
||||
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
|
||||
else:
|
||||
console.print(f"[bold green]robot:[/] {stripped}")
|
||||
@@ -0,0 +1,455 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.
|
||||
|
||||
Adds:
|
||||
|
||||
* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
|
||||
* a forward path that runs the flow head, the text head, or both,
|
||||
driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``
|
||||
produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on
|
||||
this branch).
|
||||
|
||||
Per-sample routing — within one batch:
|
||||
|
||||
* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow
|
||||
loss (action chunk supervision).
|
||||
* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the
|
||||
flow loss; only its text tokens (where ``text_labels[i, t] != -100``)
|
||||
contribute to the LM-head cross-entropy.
|
||||
|
||||
Falls back to ``SmolVLAPolicy.forward`` cleanly when neither
|
||||
``text_labels`` nor ``predict_actions`` is in the batch — unannotated
|
||||
datasets keep working unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
|
||||
from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
class SmolVLA2Policy(SmolVLAPolicy):
|
||||
"""SmolVLA + re-enabled SmolVLM language head."""
|
||||
|
||||
config_class = SmolVLA2Config
|
||||
name = "smolvla2"
|
||||
|
||||
def __init__(self, config: SmolVLA2Config, **kwargs):
|
||||
if not isinstance(config, SmolVLA2Config):
|
||||
config = SmolVLA2Config(
|
||||
**{
|
||||
f.name: getattr(config, f.name)
|
||||
for f in config.__dataclass_fields__.values()
|
||||
if hasattr(config, f.name)
|
||||
}
|
||||
)
|
||||
super().__init__(config, **kwargs)
|
||||
if config.unfreeze_lm_head and config.text_loss_weight > 0:
|
||||
self._unfreeze_lm_head()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Backbone surgery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _unfreeze_lm_head(self) -> None:
|
||||
"""Re-enable gradients on the SmolVLM ``lm_head`` (and the bits
|
||||
of the text path SmolVLA freezes) so the text-loss can flow back.
|
||||
"""
|
||||
vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
|
||||
if vlm_with_expert is None:
|
||||
return
|
||||
vlm = getattr(vlm_with_expert, "vlm", None)
|
||||
if vlm is None:
|
||||
return
|
||||
for name, param in vlm.named_parameters():
|
||||
if "lm_head" in name or "text_model.model.norm.weight" in name:
|
||||
param.requires_grad = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
noise: Tensor | None = None,
|
||||
time: Tensor | None = None,
|
||||
reduction: str = "mean",
|
||||
) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Forward pass with optional dual-head loss.
|
||||
|
||||
Two routing knobs from the batch (produced by
|
||||
:class:`SmolVLA2ChatTokenizerStep`):
|
||||
|
||||
* ``text_labels`` — per-token labels with ``-100`` for non-target
|
||||
positions. Triggers the text-loss path through ``lm_head``.
|
||||
* ``predict_actions`` — per-sample bool tensor. ``True`` ⇒
|
||||
include this sample's action chunk in the flow loss.
|
||||
|
||||
When neither is present, delegate to ``SmolVLAPolicy.forward``.
|
||||
"""
|
||||
text_labels = batch.get("text_labels")
|
||||
predict_actions_t = batch.get("predict_actions")
|
||||
|
||||
has_text_data = (
|
||||
text_labels is not None
|
||||
and isinstance(text_labels, Tensor)
|
||||
and self.config.text_loss_weight > 0
|
||||
)
|
||||
has_per_sample_routing = (
|
||||
predict_actions_t is not None and isinstance(predict_actions_t, Tensor)
|
||||
)
|
||||
|
||||
if not has_text_data and not has_per_sample_routing:
|
||||
return super().forward(batch, noise=noise, time=time, reduction=reduction)
|
||||
|
||||
loss_dict: dict[str, Any] = {}
|
||||
device = batch[OBS_STATE].device
|
||||
total = torch.zeros((), device=device, dtype=torch.float32)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Flow loss path — only when at least one sample wants actions.
|
||||
# ------------------------------------------------------------
|
||||
run_flow = self.config.flow_loss_weight > 0 and (
|
||||
not has_per_sample_routing or bool(predict_actions_t.any().item())
|
||||
)
|
||||
if run_flow and ACTION in batch:
|
||||
per_sample_flow, flow_diag = super().forward(
|
||||
batch, noise=noise, time=time, reduction="none"
|
||||
)
|
||||
# ``per_sample_flow`` has shape (B,) from the SmolVLA
|
||||
# reduction="none" branch.
|
||||
if has_per_sample_routing:
|
||||
mask = predict_actions_t.to(per_sample_flow.dtype)
|
||||
masked = per_sample_flow * mask
|
||||
denom = mask.sum().clamp(min=1.0)
|
||||
flow_loss = masked.sum() / denom
|
||||
else:
|
||||
flow_loss = per_sample_flow.mean()
|
||||
total = total + self.config.flow_loss_weight * flow_loss
|
||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||
for k, v in flow_diag.items():
|
||||
loss_dict[f"flow_{k}"] = v
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Text loss path — prefix-only forward → lm_head → CE.
|
||||
# ------------------------------------------------------------
|
||||
if has_text_data:
|
||||
text_loss = self._compute_text_loss(batch, text_labels)
|
||||
total = total + self.config.text_loss_weight * text_loss
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
|
||||
loss_dict["loss"] = float(total.detach().item())
|
||||
|
||||
if reduction == "none":
|
||||
# Per-sample loss isn't meaningfully defined for the dual
|
||||
# path; broadcast the scalar to (B,) for caller compat.
|
||||
return total.expand(batch[OBS_STATE].shape[0]), loss_dict
|
||||
return total, loss_dict
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Text-loss internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||
"""Cross-entropy on the SmolVLM ``lm_head`` over target tokens."""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Prefix-only forward.
|
||||
out_pair, _ = self.model.vlm_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||
if prefix_out is None:
|
||||
raise RuntimeError(
|
||||
"SmolVLA2: vlm_with_expert.forward returned no prefix hidden "
|
||||
"states — text-loss path needs them."
|
||||
)
|
||||
|
||||
# Lang token positions inside the prefix. ``embed_prefix`` lays
|
||||
# out the prefix as ``[image_blocks..., lang, state]`` so the
|
||||
# lang range is identifiable from the trailing state size and
|
||||
# the known lang length.
|
||||
num_lang = lang_tokens.shape[1]
|
||||
state_for_dim = state if state.ndim >= 2 else state[:, None]
|
||||
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
|
||||
if num_state < 1:
|
||||
num_state = 1
|
||||
prefix_len = prefix_out.shape[1]
|
||||
lang_end = prefix_len - num_state
|
||||
lang_start = lang_end - num_lang
|
||||
if lang_start < 0 or lang_end > prefix_len:
|
||||
raise RuntimeError(
|
||||
f"SmolVLA2: could not locate lang token range in prefix "
|
||||
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
|
||||
f"num_state={num_state})."
|
||||
)
|
||||
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||
logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab)
|
||||
|
||||
if text_labels.shape[1] != num_lang:
|
||||
common = min(text_labels.shape[1], num_lang)
|
||||
logits = logits[:, :common]
|
||||
text_labels = text_labels[:, :common]
|
||||
|
||||
# Standard next-token CE: hidden state at position t predicts
|
||||
# token at position t+1. Shift logits left, labels right by 1.
|
||||
# Without this, the loss is identity-mapped and the LM head
|
||||
# learns nothing useful — see HuggingFace ``LlamaForCausalLM``
|
||||
# for the same convention.
|
||||
shift_logits = logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
valid_labels = shift_labels != -100
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
reduction="sum",
|
||||
)
|
||||
return loss / valid_labels.sum().clamp(min=1)
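
    # Shift example (illustrative): with lang positions [p0, p1, p2, p3] and
    # text_labels [-100, -100, t2, t3], the hidden state at p1 is trained to
    # predict t2 and the one at p2 to predict t3; -100 positions are excluded
    # from both the summed loss and the normalising token count.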
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inference: text generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def select_message(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
*,
|
||||
max_new_tokens: int = 256,
|
||||
min_new_tokens: int = 0,
|
||||
eos_token_id: int | None = None,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
tokenizer: Any = None,
|
||||
) -> str:
|
||||
"""Generate text continuation from the chat-templated prompt.
|
||||
|
||||
AR decoding with KV caching reused from SmolVLA's inference
|
||||
path. Batch size is assumed to be 1 (the runtime calls this
|
||||
per-event). Returns the decoded string of new tokens (the
|
||||
prompt itself is not included).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
batch:
|
||||
Already through the SmolVLA2 preprocessor — expects
|
||||
``OBS_IMAGES_*``, ``OBS_STATE``, ``OBS_LANGUAGE_TOKENS``,
|
||||
``OBS_LANGUAGE_ATTENTION_MASK``.
|
||||
max_new_tokens:
|
||||
Hard cap on generated tokens; stops earlier on EOS.
|
||||
eos_token_id:
|
||||
Override the tokenizer's EOS. ``None`` ⇒ use the
|
||||
tokenizer's default.
|
||||
temperature, top_p:
|
||||
``temperature=0`` does greedy argmax (default — matches
|
||||
training distribution most closely). Set ``temperature>0``
|
||||
with optional ``top_p<1`` for nucleus sampling.
|
||||
tokenizer:
|
||||
Optional pre-loaded tokenizer to avoid the cold-start
|
||||
``AutoTokenizer.from_pretrained`` round-trip on every call.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if tokenizer is None:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.config.vlm_model_name)
|
||||
if eos_token_id is None:
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
# Build the full set of special-token ids to suppress during
|
||||
# the ``min_new_tokens`` window. EOS alone is not enough on a
|
||||
# memorised SmolVLM head — when EOS is masked, the argmax
|
||||
# falls onto a sibling special token (``<|im_end|>``,
|
||||
# ``<image>``, ``<fake_token_around_image>``, ``<row_X_col_Y>``,
|
||||
# …) which then survives generation but gets stripped by
|
||||
# ``skip_special_tokens=True`` so ``decode`` returns an empty
|
||||
# string and the runtime sees ``last_raw='(empty)'`` every
|
||||
# chunk boundary.
|
||||
special_ids_set: set[int] = set()
|
||||
try:
|
||||
for sid in (tokenizer.all_special_ids or []):
|
||||
if sid is not None:
|
||||
special_ids_set.add(int(sid))
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
if eos_token_id is not None:
|
||||
special_ids_set.add(int(eos_token_id))
|
||||
|
||||
# Match training's text-loss forward path (see
|
||||
# ``_compute_text_loss`` above): build the full prefix via
|
||||
# ``embed_prefix`` so images + state conditioning is intact,
|
||||
# then loop AR with ``fill_kv_cache=True, use_cache=False``.
|
||||
# That flag combo routes every layer through
|
||||
# ``forward_attn_layer`` (which gracefully skips ``None``
|
||||
# expert inputs via ``if hidden_states is None or layer is
|
||||
# None: continue``) and short-circuits the cache-update logic
|
||||
# so we don't have to manage past_kv. Each step just
|
||||
# re-forwards the cumulative ``[prefix + generated]``
|
||||
# sequence.
|
||||
#
|
||||
# This is O(n²) in generated text length but cheap in
|
||||
# absolute terms: image encoding happens once via the initial
|
||||
# ``embed_prefix`` call, and the per-step cost is just one
|
||||
# SmolVLM transformer pass over a sequence that grows by one
|
||||
# token each time. Standard SmolVLM ``generate`` was the
|
||||
# other tempting path, but it can't accept SmolVLA's custom
|
||||
# ``state_proj`` output and its tile-grid expectations
|
||||
# disagree with our preprocessor — both lead to garbage
|
||||
# generation, which is what the prior approach produced.
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
|
||||
device = prefix_embs.device
|
||||
bsize = prefix_embs.shape[0]
|
||||
vlm = self.model.vlm_with_expert.vlm
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
text_emb_scale = math.sqrt(emb_dim)
|
||||
|
||||
current_embs = prefix_embs
|
||||
current_pad = prefix_pad_masks
|
||||
current_att = prefix_att_masks
|
||||
ones_step = torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
|
||||
generated: list[int] = []
|
||||
for _ in range(max_new_tokens):
|
||||
full_2d = make_att_2d_masks(current_pad, current_att)
|
||||
full_pos = torch.cumsum(current_pad, dim=1) - 1
|
||||
|
||||
out_pair, _ = self.model.vlm_with_expert.forward(
|
||||
attention_mask=full_2d,
|
||||
position_ids=full_pos,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[current_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair
|
||||
if prefix_out is None:
|
||||
raise RuntimeError(
|
||||
"select_message: vlm_with_expert.forward returned no hidden states."
|
||||
)
|
||||
|
||||
last_hidden = prefix_out[:, -1:].to(vlm.lm_head.weight.dtype)
|
||||
logits_step = vlm.lm_head(last_hidden)[:, -1] # (B, V)
|
||||
# Suppress *all* special tokens until we've decoded
|
||||
# ``min_new_tokens`` real (renderable) tokens. Without
|
||||
# this, a memorised SmolVLM head whose argmax at position
|
||||
# 0 is a special token produces an empty completion every
|
||||
# time — either EOS directly, or (after we mask EOS) the
|
||||
# argmax shifts to a sibling special id (``<|im_end|>``,
|
||||
# ``<image>``, ``<row_X_col_Y>``, …) which decode strips
|
||||
# via ``skip_special_tokens=True``. Masking the full
|
||||
# ``all_special_ids`` set for the first N steps forces
|
||||
# the head to commit to a normal vocabulary token before
|
||||
# it can close (or quietly poison) the turn.
|
||||
if special_ids_set and len(generated) < min_new_tokens:
|
||||
for sid in special_ids_set:
|
||||
logits_step[..., sid] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
tok_id = int(next_ids[0].item())
|
||||
generated.append(tok_id)
|
||||
if eos_token_id is not None and tok_id == eos_token_id:
|
||||
break
|
||||
|
||||
new_emb = self.model.vlm_with_expert.embed_language_tokens(
|
||||
next_ids.unsqueeze(0)
|
||||
)
|
||||
new_emb = new_emb * text_emb_scale
|
||||
current_embs = torch.cat([current_embs, new_emb], dim=1)
|
||||
current_pad = torch.cat([current_pad, ones_step], dim=1)
|
||||
current_att = torch.cat([current_att, ones_step], dim=1)
|
||||
|
||||
decoded = tokenizer.decode(generated, skip_special_tokens=True).strip()
|
||||
# When the visible decoded string is empty but tokens *were*
|
||||
# generated, expose what those raw tokens decoded to without
|
||||
# the special-token filter. This is what the runtime turns
|
||||
# into a scrollback line when ``last_raw='(empty)'`` so the
|
||||
# operator can tell whether the head is emitting EOS, image
|
||||
# placeholder tokens, the chat-template ``<|im_end|>`` shard,
|
||||
# or something else.
|
||||
if not decoded and generated:
|
||||
try:
|
||||
self._last_select_message_debug = (
|
||||
f"raw_ids={generated[:16]} "
|
||||
f"decoded_w_special={tokenizer.decode(generated, skip_special_tokens=False)!r}"
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
self._last_select_message_debug = f"raw_ids={generated[:16]}"
|
||||
else:
|
||||
self._last_select_message_debug = ""
|
||||
return decoded
|
||||
|
||||
@staticmethod
|
||||
def _sample_next_token(
|
||||
logits: Tensor, temperature: float, top_p: float
|
||||
) -> Tensor:
|
||||
"""Pick one token id per batch row from ``logits``."""
|
||||
if temperature <= 0.0:
|
||||
return logits.argmax(dim=-1)
|
||||
scaled = logits / max(temperature, 1e-6)
|
||||
probs = F.softmax(scaled, dim=-1)
|
||||
if top_p < 1.0:
|
||||
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
|
||||
cum = sorted_probs.cumsum(dim=-1)
|
||||
mask = cum > top_p
|
||||
# Always keep the most-likely token.
|
||||
mask[..., 0] = False
|
||||
sorted_probs = sorted_probs.masked_fill(mask, 0.0)
|
||||
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
|
||||
pick = torch.multinomial(sorted_probs, num_samples=1)
|
||||
return sorted_idx.gather(-1, pick).squeeze(-1)
|
||||
return torch.multinomial(probs, num_samples=1).squeeze(-1)
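

# Decoding sketch (illustrative; ``policy`` and the preprocessed ``batch`` are
# assumptions): greedy decoding by default, nucleus sampling when a positive
# temperature is given.
#
#   subtask = policy.select_message(batch, max_new_tokens=64, min_new_tokens=4)
#   sampled = policy.select_message(batch, temperature=0.7, top_p=0.9)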
|
||||
@@ -0,0 +1,134 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SmolVLA2 processor pipelines.
|
||||
|
||||
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||
|
||||
rename observations
|
||||
add batch dim
|
||||
RenderMessagesStep(recipe) # PR 1: language_* → messages
|
||||
SmolVLA2ChatTokenizerStep(...) # chat template + label mask + predict_actions
|
||||
DeviceProcessorStep
|
||||
NormalizerProcessorStep
|
||||
|
||||
When ``config.recipe_path`` is ``None``, we delegate to SmolVLA's
|
||||
plain task-string pipeline so unannotated datasets still work.
|
||||
|
||||
Post-processor is unchanged from SmolVLA.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
RenderMessagesStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
from .chat_processor_smolvla2 import SmolVLA2ChatTokenizerStep
|
||||
from .configuration_smolvla2 import SmolVLA2Config
|
||||
|
||||
|
||||
def make_smolvla2_pre_post_processors(
|
||||
config: SmolVLA2Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build SmolVLA2's pre/post-processor pipelines.
|
||||
|
||||
With ``recipe_path`` set, inserts the recipe-rendering step and the
|
||||
chat-template tokenizer that emits ``text_labels`` and
|
||||
``predict_actions`` for the dual-loss path. Without it, falls back
|
||||
to SmolVLA's plain task-string pipeline so unannotated datasets
|
||||
keep working unchanged.
|
||||
"""
|
||||
if not config.recipe_path:
|
||||
return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||
|
||||
recipe = _load_recipe(config.recipe_path)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
RenderMessagesStep(recipe=recipe),
|
||||
SmolVLA2ChatTokenizerStep(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding=config.pad_language_to,
|
||||
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
|
||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/`` so recipe authors can write
|
||||
``--policy.recipe_path=recipes/smolvla2_hirobot.yaml``.
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||
|
||||
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||
candidate = configs_dir / path_str
|
||||
if candidate.exists():
|
||||
p = candidate
|
||||
return TrainingRecipe.from_yaml(p)
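

# Construction sketch (illustrative; the recipe path is an assumption):
#
#   cfg = SmolVLA2Config(recipe_path="recipes/smolvla2_hirobot.yaml")
#   preprocess, postprocess = make_smolvla2_pre_post_processors(cfg, dataset_stats=stats)
#
# With ``recipe_path`` unset, the same call returns SmolVLA's plain
# task-string pipelines, so unannotated datasets run through an unchanged path.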
|
||||
@@ -93,6 +93,7 @@ from .relative_action_processor import (
|
||||
to_relative_actions,
|
||||
)
|
||||
from .rename_processor import RenameObservationsProcessorStep, rename_stats
|
||||
from .render_messages_processor import RenderMessagesStep
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
__all__ = [
|
||||
@@ -128,6 +129,7 @@ __all__ = [
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"RenderMessagesStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NewLineTaskProcessorStep",
|
||||
|
||||
@@ -174,6 +174,21 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
complementary_data["messages"] = [messages]
|
||||
|
||||
if "message_streams" in complementary_data:
|
||||
streams = complementary_data["message_streams"]
|
||||
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
|
||||
complementary_data["message_streams"] = [streams]
|
||||
|
||||
if "target_message_indices" in complementary_data:
|
||||
indices = complementary_data["target_message_indices"]
|
||||
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
|
||||
complementary_data["target_message_indices"] = [indices]
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -167,12 +167,35 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
|
||||
language_persistent_key = (
|
||||
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
|
||||
)
|
||||
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
|
||||
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
|
||||
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
|
||||
target_message_indices_key = (
|
||||
{"target_message_indices": batch["target_message_indices"]}
|
||||
if "target_message_indices" in batch
|
||||
else {}
|
||||
)
|
||||
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {
|
||||
**pad_keys,
|
||||
**task_key,
|
||||
**index_key,
|
||||
**task_index_key,
|
||||
**episode_index_key,
|
||||
**timestamp_key,
|
||||
**language_persistent_key,
|
||||
**language_events_key,
|
||||
**messages_key,
|
||||
**message_streams_key,
|
||||
**target_message_indices_key,
|
||||
}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="render_messages_processor")
|
||||
class RenderMessagesStep(ProcessorStep):
|
||||
"""Processor step that turns raw language columns into rendered chat messages.
|
||||
|
||||
Reads ``language_persistent`` and ``language_events`` from the transition's
|
||||
complementary data, renders them through ``recipe`` at the sample timestamp,
|
||||
and replaces the raw columns with the resulting ``messages`` /
|
||||
``message_streams`` / ``target_message_indices`` keys.
|
||||
"""
|
||||
|
||||
recipe: TrainingRecipe
|
||||
dataset_ctx: Any | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
"""Render messages for a single transition; return ``None`` to drop it."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
if _is_batched_language(persistent) or _is_batched_language(events):
|
||||
return self._call_batch(transition, complementary_data, persistent, events)
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
sample_idx = complementary_data.get("index", 0)
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=_scalar(timestamp),
|
||||
sample_idx=int(_scalar(sample_idx)),
|
||||
task=complementary_data.get("task"),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def _call_batch(
|
||||
self,
|
||||
transition: EnvTransition,
|
||||
complementary_data: dict[str, Any],
|
||||
persistent_batch: list,
|
||||
events_batch: list,
|
||||
) -> EnvTransition | None:
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
batch_size = max(len(persistent_batch), len(events_batch))
|
||||
messages: list[list[dict[str, Any]]] = []
|
||||
message_streams: list[list[str | None]] = []
|
||||
target_message_indices: list[list[int]] = []
|
||||
keep_indices: list[int] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
|
||||
events=events_batch[i] if i < len(events_batch) else [],
|
||||
t=_batch_value(timestamp, i),
|
||||
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
|
||||
task=_batch_value(complementary_data.get("task"), i),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
continue
|
||||
keep_indices.append(i)
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
target_message_indices.append(rendered["target_message_indices"])
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
new_transition = (
|
||||
_select_batch_indices(transition, keep_indices)
|
||||
if len(keep_indices) != batch_size
|
||||
else transition.copy()
|
||||
)
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data["messages"] = messages
|
||||
new_complementary_data["message_streams"] = message_streams
|
||||
new_complementary_data["target_message_indices"] = target_message_indices
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
|
||||
|
||||
def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list) and len(value) == 1:
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
|
||||
def _is_batched_language(value: Any) -> bool:
|
||||
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
|
||||
|
||||
|
||||
def _batch_value(value: Any, index: int) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return value[index]
|
||||
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
|
||||
return _scalar(value[index])
|
||||
return _scalar(value)
|
||||
|
||||
|
||||
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
|
||||
selected = transition.copy()
|
||||
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
|
||||
data = selected.get(key)
|
||||
if isinstance(data, dict):
|
||||
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
|
||||
action = selected.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
selected[TransitionKey.ACTION] = _select_value(action, indices)
|
||||
return selected
|
||||
|
||||
|
||||
def _select_value(value: Any, indices: list[int]) -> Any:
|
||||
if isinstance(value, list) and len(value) >= len(indices):
|
||||
return [value[i] for i in indices]
|
||||
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
||||
return value.index_select(0, value.new_tensor(indices).long())
|
||||
return value
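

# Illustrative shape check (the helper is hypothetical, not part of the step):
# per-sample language columns are lists of dicts; after AddBatchDimension they
# become lists of lists, which is what routes a transition into ``_call_batch``.
def _batched_language_examples() -> None:
    single = [{"text": "pick up the cube"}]
    assert not _is_batched_language(single)
    assert _is_batched_language([single])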
|
||||
@@ -21,6 +21,8 @@ from lerobot.utils.import_utils import make_device_from_device_class
|
||||
from .config import RobotConfig
|
||||
from .robot import Robot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
@@ -110,7 +112,7 @@ def ensure_safe_goal_position(
|
||||
}
|
||||
|
||||
if warnings_dict:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f"{pformat(warnings_dict, indent=4)}"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``lerobot-annotate`` — populate ``language_persistent`` and
|
||||
``language_events`` columns on a LeRobot dataset.
|
||||
|
||||
Annotations live directly in ``data/chunk-*/file-*.parquet``: there is no
|
||||
flavor namespace and no sidecar tree. Multiple revisions of the same dataset
|
||||
mean multiple dataset copies.
|
||||
|
||||
Example:
|
||||
|
||||
uv run lerobot-annotate \\
|
||||
--root=/path/to/dataset \\
|
||||
--vlm.backend=transformers \\
|
||||
--vlm.model_id=Qwen/Qwen2.5-VL-7B-Instruct
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.frames import make_frame_provider
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import make_vlm_client
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs import parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_root(cfg: AnnotationPipelineConfig) -> Path:
|
||||
if cfg.root is not None:
|
||||
return Path(cfg.root)
|
||||
if cfg.repo_id is not None:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
return Path(snapshot_download(repo_id=cfg.repo_id, repo_type="dataset"))
|
||||
raise ValueError("Either --root or --repo_id must be provided.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Run the steerable annotation pipeline against a dataset."""
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
root = _resolve_root(cfg)
|
||||
logger.info("annotate: root=%s", root)
|
||||
|
||||
vlm = make_vlm_client(cfg.vlm)
|
||||
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key)
|
||||
# Surface the resolved cameras up front so silent Module-3-no-op
|
||||
# regressions are obvious in job output rather than discovered post-hoc
|
||||
# by counting parquet rows.
|
||||
cam_keys = list(getattr(frame_provider, "camera_keys", []) or [])
|
||||
logger.info(
|
||||
"annotate: frame_provider default camera=%r, all cameras=%s",
|
||||
getattr(frame_provider, "camera_key", None),
|
||||
cam_keys,
|
||||
)
|
||||
if cfg.module_3.enabled and not cam_keys:
|
||||
logger.warning(
|
||||
"annotate: Module 3 (VQA) is enabled but no cameras were "
|
||||
"resolved — Module 3 will produce zero VQA rows. Check "
|
||||
"meta/info.json for observation.images.* features, or pass "
|
||||
"--vlm.camera_key=<key> to seed the cameras list."
|
||||
)
|
||||
module_1 = PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1, frame_provider=frame_provider)
|
||||
module_2 = InterjectionsAndSpeechModule(
|
||||
vlm=vlm, config=cfg.module_2, seed=cfg.seed, frame_provider=frame_provider
|
||||
)
|
||||
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
module_1=module_1,
|
||||
module_2=module_2,
|
||||
module_3=module_3,
|
||||
writer=writer,
|
||||
validator=validator,
|
||||
)
|
||||
summary = executor.run(root)
|
||||
logger.info("annotate: wrote %d shard(s)", len(summary.written_paths))
|
||||
for phase in summary.phases:
|
||||
logger.info(
|
||||
"annotate: phase=%s processed=%d skipped=%d",
|
||||
phase.name,
|
||||
phase.episodes_processed,
|
||||
phase.episodes_skipped,
|
||||
)
|
||||
if summary.validation_report.warnings:
|
||||
for w in summary.validation_report.warnings:
|
||||
logger.warning(w)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
_push_to_hub(root, cfg)
|
||||
|
||||
|
||||
def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
|
||||
"""Upload the annotated dataset directory to the Hugging Face Hub."""
|
||||
from huggingface_hub import HfApi # noqa: PLC0415
|
||||
|
||||
repo_id = cfg.push_to_hub
|
||||
commit_message = cfg.push_commit_message or "Add steerable annotations (lerobot-annotate)"
|
||||
api = HfApi()
|
||||
print(f"[lerobot-annotate] creating/locating dataset repo {repo_id}...", flush=True)
|
||||
api.create_repo(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
private=cfg.push_private,
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f"[lerobot-annotate] uploading {root} -> {repo_id}...", flush=True)
|
||||
api.upload_folder(
|
||||
folder_path=str(root),
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
ignore_patterns=[".annotate_staging/**", "**/.DS_Store"],
|
||||
)
|
||||
print(f"[lerobot-annotate] uploaded to https://huggingface.co/datasets/{repo_id}", flush=True)
|
||||
|
||||
# Tag the upload with the codebase version. ``LeRobotDatasetMetadata``
|
||||
# resolves the dataset revision via ``get_safe_version`` which scans
|
||||
# for tags like ``v3.0``; without a tag it raises
|
||||
# ``RevisionNotFoundError``. Read the version straight from the
|
||||
# dataset's own ``meta/info.json`` so we tag whatever the writer
|
||||
# actually wrote (no accidental drift if the codebase floor moves).
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION # noqa: PLC0415
|
||||
|
||||
info_path = root / "meta" / "info.json"
|
||||
version_tag = CODEBASE_VERSION
|
||||
if info_path.exists():
|
||||
try:
|
||||
from lerobot.utils.io_utils import load_json # noqa: PLC0415
|
||||
|
||||
info = load_json(info_path)
|
||||
ds_version = info.get("codebase_version")
|
||||
if isinstance(ds_version, str) and ds_version.startswith("v"):
|
||||
version_tag = ds_version
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}", flush=True)
|
||||
try:
|
||||
api.create_tag(
|
||||
repo_id=repo_id,
|
||||
tag=version_tag,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
print(f"[lerobot-annotate] tagged {repo_id} as {version_tag}", flush=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(
|
||||
f"[lerobot-annotate] WARNING: could not create tag {version_tag!r} on {repo_id}: {exc}. "
|
||||
"Dataset is uploaded but ``LeRobotDataset`` won't be able to load it until it's tagged. "
|
||||
"Run: from huggingface_hub import HfApi; "
|
||||
f"HfApi().create_tag({repo_id!r}, tag={version_tag!r}, repo_type='dataset', exist_ok=True)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
annotate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -386,6 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
collate_fn=lerobot_collate_fn,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LeRobot tool implementations.

Storage of the tool catalog (``meta/info.json["tools"]``) and the
``SAY_TOOL_SCHEMA`` constant live in PR 1
(``lerobot.datasets.language``). This package holds the *runnable*
implementations, one file per tool, plus the registry that maps tool
names to classes.

See ``docs/source/tools.mdx`` for the authoring guide.
"""

from .base import Tool
from .registry import TOOL_REGISTRY, get_tools
from .say import SayTool

__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
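
# Roughly what a catalog entry in ``meta/info.json["tools"]`` looks like. Only
# ``function.name`` is relied on by the registry; the surrounding shape follows
# the OpenAI function-calling convention and is a sketch here, not the exact
# ``SAY_TOOL_SCHEMA``:
#
#     "tools": [
#         {
#             "type": "function",
#             "function": {
#                 "name": "say",
#                 "parameters": {
#                     "type": "object",
#                     "properties": {"text": {"type": "string"}},
#                 },
#             },
#         }
#     ]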
@@ -0,0 +1,58 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool protocol — the contract every runnable tool implementation honors.

Tools are the executable side of the OpenAI-style function-calling
abstraction the v3.1 language schema (PR 1) carries on assistant
messages: the schema describes *what can be called*, the tool
implementation describes *how to call it*.

Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
``say.py`` for ``SayTool``) and are registered in
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
heavy dependencies (torch models, audio backends, network clients,
hardware drivers) only load when the dataset actually declares the tool.
"""

from __future__ import annotations

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Tool(Protocol):
    """Minimum surface every tool must expose."""

    #: Name matching ``schema["function"]["name"]``. The runtime dispatcher
    #: routes incoming ``tool_calls`` to the implementation by this key.
    name: str

    #: OpenAI-style function-call schema. Same dict the dataset stores in
    #: ``meta/info.json["tools"]`` and the chat template renders into the
    #: prompt.
    schema: dict[str, Any]

    def call(self, arguments: dict[str, Any]) -> Any:
        """Execute the tool with the model-provided arguments.

        ``arguments`` is the parsed dict from
        ``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
        when the model emits a JSON string, per the chat-template
        convention). Implementations validate the dict against their own
        schema; the runtime only routes by name.

        Return value is implementation-defined — typically a tensor
        (TTS audio), a Path (saved file), a dict (structured result), or
        ``None`` (side-effect-only call).
        """
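
# A minimal conforming implementation, as a sketch (the ``wave`` tool and its
# schema are made up for illustration; only ``name``, ``schema`` and ``call``
# are required by the protocol):
#
#     class WaveTool:
#         name = "wave"
#         schema = {"type": "function", "function": {"name": "wave", "parameters": {}}}
#
#         def call(self, arguments: dict[str, Any]) -> None:
#             print("waving")  # side-effect-only tool, returns None
#
#     assert isinstance(WaveTool(), Tool)  # Tool is runtime_checkable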
@@ -0,0 +1,70 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool registry — name → implementation class.

Adding a new tool:

1. Drop a file under ``src/lerobot/tools/`` that defines a class
   conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
   ``schema``, ``call(arguments)``).
2. Register the class here under :data:`TOOL_REGISTRY`.
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
   to advertise the schema to the chat-template + policy. The PR 2
   annotation pipeline preserves anything you put there.

See ``docs/source/tools.mdx`` for the full authoring guide.
"""

from __future__ import annotations

from typing import Any

from .base import Tool
from .say import SayTool

#: Map from ``function.name`` to a class implementing :class:`Tool`.
#: The runtime instantiates entries lazily — registering a tool here is
#: essentially free (no model load happens until ``call`` runs).
TOOL_REGISTRY: dict[str, type] = {
    "say": SayTool,
}


def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
    """Build name → tool-instance dict from a dataset's declared catalog.

    ``meta`` is anything with a ``.tools`` attribute returning the
    OpenAI-style schema list — typically a
    :class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
    Each entry whose ``function.name`` is registered here is
    instantiated with the schema dict; tools whose name is unknown to
    the registry are skipped (the schema still rides through the chat
    template, the model just can't actually invoke that tool at
    inference).

    Extra keyword arguments are forwarded to every constructor — useful
    for runtime defaults like ``output_dir=Path("./tts_log")``.
    """
    declared = list(meta.tools)
    instances: dict[str, Tool] = {}
    for schema in declared:
        try:
            name = schema["function"]["name"]
        except (KeyError, TypeError):
            continue
        cls = TOOL_REGISTRY.get(name)
        if cls is None:
            continue
        instances[name] = cls(schema=schema, **kwargs)
    return instances
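
# Registration and lookup sketch. ``WaveTool`` is hypothetical; ``_Meta`` stands
# in for ``LeRobotDatasetMetadata``, and ``SAY_TOOL_SCHEMA`` comes from
# ``lerobot.datasets.language``:
#
#     from pathlib import Path
#     from lerobot.datasets.language import SAY_TOOL_SCHEMA
#
#     TOOL_REGISTRY["wave"] = WaveTool
#
#     class _Meta:
#         tools = [SAY_TOOL_SCHEMA]
#
#     tools = get_tools(_Meta(), output_dir=Path("./tts_log"))
#     assert set(tools) == {"say"}  # "wave" is not declared by this dataset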
@@ -0,0 +1,170 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.

The first concrete tool implementation. SmolVLA2 (PR 3) and downstream
runtime dispatchers consume this when the model emits an assistant
message with ``tool_calls=[{function: {name: "say", arguments:
{text: ...}}}]``.

Why pocket-tts:

- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
- ~100M parameters, ~200ms first-chunk latency
- streamable, voice-cloneable
- pip-installable, MIT-style permissive license

The pocket-tts model is loaded **lazily** the first time ``call(...)``
runs (or eagerly via ``preload()``). Loading takes a few seconds and
several hundred MB of RAM, so we don't pay the cost when the tool is
merely *registered* — only when it's *invoked*.

Optional dependency. Install with::

    pip install lerobot[tools]
    # or directly:
    pip install pocket-tts
"""

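# Usage sketch (live playback via ``sounddevice`` is an assumption, not a
# lerobot dependency; any sink that accepts float32 PCM works):
#
#     tool = SayTool(voice="alba", output_dir=Path("./tts_log"))
#     audio = tool.call({"text": "Picking up the bottle."})
#     # sounddevice.play(audio.numpy(), tool.sample_rate)
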
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SayTool:
|
||||
"""Speak a short utterance via Kyutai's pocket-tts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema:
|
||||
Optional schema override; defaults to the canonical
|
||||
``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
|
||||
argument shapes can pass in a modified schema, but the
|
||||
implementation only reads ``arguments["text"]``.
|
||||
voice:
|
||||
One of the pocket-tts catalog voices (``alba``, ``marius``,
|
||||
``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
|
||||
``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
|
||||
file for cloning. See the pocket-tts model card for licensing.
|
||||
output_dir:
|
||||
If set, every ``call(...)`` writes a ``<timestamp>.wav`` audio
|
||||
file there in addition to returning the PCM tensor.
|
||||
``None`` (default) skips disk writes — useful for live
|
||||
playback paths that hand the tensor directly to a sounddevice
|
||||
/ WebAudio sink.
|
||||
"""
|
||||
|
||||
schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
|
||||
voice: str = "alba"
|
||||
output_dir: Path | None = None
|
||||
|
||||
name: str = field(init=False, default="say")
|
||||
_model: Any = field(init=False, default=None, repr=False)
|
||||
_voice_state: Any = field(init=False, default=None, repr=False)
|
||||
_sample_rate: int = field(init=False, default=24000, repr=False)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lazy model load
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def preload(self) -> None:
|
||||
"""Load the pocket-tts model + voice state into memory.
|
||||
|
||||
Optional — ``call(...)`` triggers this automatically on first
|
||||
invocation. Useful when you want the multi-second load to
|
||||
happen at startup rather than on the first ``say`` the policy
|
||||
emits.
|
||||
"""
|
||||
if self._model is not None and self._voice_state is not None:
|
||||
return
|
||||
try:
|
||||
from pocket_tts import TTSModel # noqa: PLC0415 (optional dep)
|
||||
except ImportError as exc: # pragma: no cover (env-dependent)
|
||||
raise ImportError(
|
||||
"SayTool requires pocket-tts. Install with `pip install "
|
||||
"lerobot[tools]` or `pip install pocket-tts`."
|
||||
) from exc
|
||||
logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
|
||||
self._model = TTSModel.load_model()
|
||||
self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
|
||||
self._sample_rate = int(getattr(self._model, "sample_rate", 24000))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool protocol
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Speak ``arguments["text"]`` and return the PCM tensor.
|
||||
|
||||
Optionally also writes ``<output_dir>/<timestamp>.wav`` when
|
||||
``self.output_dir`` is set. The returned tensor is a 1-D
|
||||
``torch.Tensor`` of float32 PCM samples at
|
||||
``self.sample_rate`` Hz — directly playable by
|
||||
``sounddevice.play(audio.numpy(), self.sample_rate)`` or
|
||||
encodable by ``scipy.io.wavfile.write``.
|
||||
"""
|
||||
text = arguments.get("text")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
raise ValueError(
|
||||
f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
|
||||
)
|
||||
self.preload()
|
||||
|
||||
audio = self._model.generate_audio(self._voice_state, text)
|
||||
|
||||
if self.output_dir is not None:
|
||||
self._write_wav(audio, text)
|
||||
|
||||
return audio
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""PCM sample rate of the returned tensor (Hz)."""
|
||||
return self._sample_rate
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _write_wav(self, audio: Any, text: str) -> Path:
|
||||
"""Write a ``.wav`` next to ``output_dir`` for offline inspection."""
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
try:
|
||||
import scipy.io.wavfile # noqa: PLC0415
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError(
|
||||
"SayTool.output_dir requires scipy. `pip install scipy`."
|
||||
) from exc
|
||||
|
||||
out_dir = Path(self.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# One file per call; suffix with a millisecond timestamp + a
|
||||
# short text snippet so a directory listing is informative.
|
||||
snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
|
||||
ts_ms = int(_time.time() * 1000)
|
||||
path = out_dir / f"say_{ts_ms}_{snippet}.wav"
|
||||
|
||||
# ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
|
||||
# ``.numpy()`` is safe.
|
||||
scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
|
||||
return path
|
||||
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from lerobot.datasets.language import LANGUAGE_COLUMNS
|
||||
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices", *LANGUAGE_COLUMNS}
|
||||
|
||||
|
||||
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
|
||||
"""Collate function that preserves Python-list and language fields as lists.
|
||||
|
||||
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
|
||||
rendered-message and language fields as plain Python lists, and delegates
|
||||
every other key to PyTorch's ``default_collate``.
|
||||
"""
|
||||
batch = [sample for sample in batch if sample is not None]
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
preserved = {
|
||||
key: [sample[key] for sample in batch if key in sample]
|
||||
for key in _PYTHON_LIST_KEYS
|
||||
if any(key in sample for sample in batch)
|
||||
}
|
||||
tensorizable = [
|
||||
{
|
||||
key: value
|
||||
for key, value in sample.items()
|
||||
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
|
||||
}
|
||||
for sample in batch
|
||||
]
|
||||
collated = default_collate(tensorizable)
|
||||
collated.update(preserved)
|
||||
return collated
|
||||
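
# Wiring sketch, mirroring the training-script change above (``dataset`` and the
# DataLoader arguments are placeholders):
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(dataset, batch_size=8, collate_fn=lerobot_collate_fn)
#     batch = next(iter(loader))
#     # e.g. batch["messages"], when present, stays a plain Python list;
#     # tensor keys are stacked by default_collate as usual.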
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Helpers shared across annotation-pipeline tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
|
||||
def make_canned_responder(
|
||||
responses_by_marker: dict[str, Any],
|
||||
default: Any = None,
|
||||
) -> StubVlmClient:
|
||||
"""Return a stub that picks a response by inspecting the user prompt.
|
||||
|
||||
For each call the responder examines the last user-message text and
|
||||
returns the response keyed by the first marker substring it contains.
|
||||
Falls back to ``default`` if no marker matches.
|
||||
"""
|
||||
|
||||
def responder(messages: list[dict[str, Any]]) -> Any:
|
||||
last_user_text = ""
|
||||
for message in messages:
|
||||
if message.get("role") != "user":
|
||||
continue
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
last_user_text = content
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
last_user_text = block.get("text", "")
|
||||
for marker, response in responses_by_marker.items():
|
||||
if marker in last_user_text:
|
||||
return response
|
||||
return default
|
||||
|
||||
return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def encode_vqa_answer(payload: dict[str, Any]) -> str:
|
||||
return json.dumps(payload, sort_keys=True)
|
||||
@@ -0,0 +1,112 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Shared fixtures for annotation-pipeline tests.
|
||||
|
||||
Builds a minimal LeRobot-shaped dataset on disk so writer/validator tests
|
||||
can exercise real parquet reads and writes without needing a checked-in
|
||||
LFS dataset.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_episode_table(
|
||||
episode_index: int,
|
||||
num_frames: int,
|
||||
*,
|
||||
fps: int = 10,
|
||||
task_index: int = 0,
|
||||
) -> pa.Table:
|
||||
timestamps = [round(i / fps, 6) for i in range(num_frames)]
|
||||
frame_indices = list(range(num_frames))
|
||||
return pa.Table.from_pydict(
|
||||
{
|
||||
"episode_index": [episode_index] * num_frames,
|
||||
"frame_index": frame_indices,
|
||||
"timestamp": timestamps,
|
||||
"task_index": [task_index] * num_frames,
|
||||
"subtask_index": [0] * num_frames, # legacy column the writer must drop
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_dataset(root: Path, episode_specs: list[tuple[int, int, str]], *, fps: int = 10) -> Path:
|
||||
"""Create a fixture dataset under ``root``.
|
||||
|
||||
``episode_specs`` is a list of ``(episode_index, num_frames, task_text)``.
|
||||
Each episode goes into its own ``data/chunk-000/file-{ep:03d}.parquet``
|
||||
so the writer's per-shard rewrite path is exercised.
|
||||
"""
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
tasks = {}
|
||||
for episode_index, num_frames, task_text in episode_specs:
|
||||
task_index = len(tasks)
|
||||
if task_text not in tasks.values():
|
||||
tasks[task_index] = task_text
|
||||
else:
|
||||
task_index = next(k for k, v in tasks.items() if v == task_text)
|
||||
table = _make_episode_table(episode_index, num_frames, fps=fps, task_index=task_index)
|
||||
path = data_dir / f"file-{episode_index:03d}.parquet"
|
||||
pq.write_table(table, path)
|
||||
|
||||
meta_dir = root / "meta"
|
||||
meta_dir.mkdir(parents=True, exist_ok=True)
|
||||
tasks_table = pa.Table.from_pydict(
|
||||
{
|
||||
"task_index": list(tasks.keys()),
|
||||
"task": list(tasks.values()),
|
||||
}
|
||||
)
|
||||
pq.write_table(tasks_table, meta_dir / "tasks.parquet")
|
||||
|
||||
info = {
|
||||
"codebase_version": "v3.1",
|
||||
"fps": fps,
|
||||
"total_episodes": len(episode_specs),
|
||||
}
|
||||
(meta_dir / "info.json").write_text(json.dumps(info, indent=2))
|
||||
|
||||
return root
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixture_dataset_root(tmp_path: Path) -> Path:
|
||||
"""A tiny dataset with two episodes, 12 frames each at 10 fps."""
|
||||
return _build_dataset(
|
||||
tmp_path / "ds",
|
||||
episode_specs=[
|
||||
(0, 12, "Could you tidy the kitchen please?"),
|
||||
(1, 12, "Please clean up the kitchen"),
|
||||
],
|
||||
fps=10,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_episode_root(tmp_path: Path) -> Path:
|
||||
return _build_dataset(
|
||||
tmp_path / "ds_one",
|
||||
episode_specs=[(0, 30, "Pour water from the bottle into the cup.")],
|
||||
fps=10,
|
||||
)
|
||||
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Opt-in E2E smoke run for ``make annotation-e2e``.
|
||||
|
||||
Builds the same fixture used by the pytest suite, runs the full
|
||||
annotation pipeline against it with a stub VLM, and prints a short report.
|
||||
This is intentionally not a pytest test — it exercises the CLI plumbing
|
||||
without depending on conftest.py fixtures.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import AnnotationPipelineConfig
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
|
||||
|
||||
def _build_dataset(root: Path) -> Path:
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
n = 30
|
||||
timestamps = [round(i / 10, 6) for i in range(n)]
|
||||
table = pa.Table.from_pydict(
|
||||
{
|
||||
"episode_index": [0] * n,
|
||||
"frame_index": list(range(n)),
|
||||
"timestamp": timestamps,
|
||||
"task_index": [0] * n,
|
||||
"subtask_index": [0] * n,
|
||||
}
|
||||
)
|
||||
pq.write_table(table, data_dir / "file-000.parquet")
|
||||
meta = root / "meta"
|
||||
meta.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(
|
||||
pa.Table.from_pydict({"task_index": [0], "task": ["Pour water into the cup."]}),
|
||||
meta / "tasks.parquet",
|
||||
)
|
||||
(meta / "info.json").write_text(json.dumps({"codebase_version": "v3.1", "fps": 10}))
|
||||
return root
|
||||
|
||||
|
||||
def _stub_responder(messages):
|
||||
text = ""
|
||||
for m in messages:
|
||||
if m.get("role") == "user":
|
||||
content = m.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
elif isinstance(content, str):
|
||||
text = content
|
||||
if "atomic subtasks" in text:
|
||||
return {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 1.0},
|
||||
{"text": "pour into the cup", "start": 1.0, "end": 2.0},
|
||||
{"text": "place the bottle down", "start": 2.0, "end": 3.0},
|
||||
]
|
||||
}
|
||||
if "concise hierarchical PLAN" in text:
|
||||
return {"plan": "1. grasp\n2. pour\n3. place"}
|
||||
if "Update the memory" in text:
|
||||
return {"memory": "poured once"}
|
||||
if "acknowledgement the robot" in text:
|
||||
return {"text": "Sure."}
|
||||
if "ONE realistic interruption" in text:
|
||||
return {"interjection": "use less water", "speech": "Using less water."}
|
||||
if "frame-grounded visual question" in text:
|
||||
return {"question": "How many cups?", "answer": {"label": "cup", "count": 1}}
|
||||
return None
|
||||
|
||||
|
||||
def main() -> int:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = _build_dataset(Path(tmp) / "ds")
|
||||
vlm = StubVlmClient(responder=_stub_responder)
|
||||
cfg = AnnotationPipelineConfig()
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=cfg.module_1),
|
||||
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=cfg.module_2, seed=cfg.seed),
|
||||
module_3=GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
summary = executor.run(root)
|
||||
print(f"phases={[(p.name, p.episodes_processed) for p in summary.phases]}")
|
||||
print(f"validation: {summary.validation_report.summary()}")
|
||||
print(f"shards rewritten: {len(summary.written_paths)}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1,304 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module 1/2/3 unit tests with stubbed VLMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
Module1Config,
|
||||
Module2Config,
|
||||
Module3Config,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.vlm_client import StubVlmClient
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StubFrameProvider:
|
||||
"""Returns one sentinel object per requested timestamp."""
|
||||
|
||||
sentinel: Any = field(default_factory=lambda: object())
|
||||
cameras: tuple[str, ...] = ("observation.images.top",)
|
||||
calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
|
||||
video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return list(self.cameras)
|
||||
|
||||
def frames_at(self, record, timestamps, camera_key=None):
|
||||
self.calls.append((record.episode_index, tuple(timestamps), camera_key))
|
||||
return [self.sentinel] * len(timestamps)
|
||||
|
||||
def video_for_episode(self, record, max_frames, camera_key=None):
|
||||
self.video_calls.append((record.episode_index, max_frames, camera_key))
|
||||
n = min(max_frames, len(record.frame_timestamps))
|
||||
return [self.sentinel] * n
|
||||
|
||||
|
||||
def _spy_responder(captured: list[list[dict[str, Any]]], reply: Any):
|
||||
def responder(messages):
|
||||
captured.append(list(messages))
|
||||
return reply
|
||||
|
||||
return StubVlmClient(responder=responder)
|
||||
|
||||
|
||||
def test_module1_plan_memory_subtask_smoke(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"atomic subtasks": {
|
||||
"subtasks": [
|
||||
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.4},
|
||||
{"text": "wipe the counter from left to right", "start": 0.4, "end": 0.8},
|
||||
{"text": "place the sponge into the sink", "start": 0.8, "end": 1.1},
|
||||
]
|
||||
},
|
||||
"concise hierarchical PLAN": {"plan": "1. grasp\n2. wipe\n3. place"},
|
||||
"Update the memory": {"memory": "wiped the counter once"},
|
||||
},
|
||||
)
|
||||
module = PlanSubtasksMemoryModule(vlm=vlm, config=Module1Config())
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_1")
|
||||
|
||||
styles = {r["style"] for r in rows}
|
||||
assert {"subtask", "plan", "memory"}.issubset(styles)
|
||||
# subtask timestamps must be exact frame timestamps
|
||||
frame_set = set(record.frame_timestamps)
|
||||
for row in rows:
|
||||
assert row["timestamp"] in frame_set
|
||||
# exactly one plan row at t0
|
||||
plan_rows = [r for r in rows if r["style"] == "plan"]
|
||||
assert len(plan_rows) == 1
|
||||
assert plan_rows[0]["timestamp"] == record.frame_timestamps[0]
|
||||
|
||||
|
||||
def test_module2_at_t0_emits_speech_only_no_interjection(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{"acknowledgement the robot": {"text": "Sure, on it."}},
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=Module2Config(max_interjections_per_episode=0),
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_2")
|
||||
assert len(rows) == 1
|
||||
only = rows[0]
|
||||
assert only["role"] == "assistant"
|
||||
assert only["style"] is None
|
||||
assert only["content"] is None
|
||||
assert only["timestamp"] == record.frame_timestamps[0]
|
||||
assert only["tool_calls"][0]["function"]["name"] == "say"
|
||||
|
||||
|
||||
def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
||||
fixture_dataset_root: Path, tmp_path: Path
|
||||
) -> None:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"acknowledgement the robot": {"text": "OK."},
|
||||
"ONE realistic interruption": {
|
||||
"interjection": "actually skip the dishes",
|
||||
"speech": "Skipping the dishes.",
|
||||
},
|
||||
},
|
||||
)
|
||||
module = InterjectionsAndSpeechModule(
|
||||
vlm=vlm,
|
||||
config=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.2),
|
||||
seed=7,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_2")
|
||||
|
||||
interjections = [r for r in rows if r["style"] == "interjection"]
|
||||
speeches = [r for r in rows if r["style"] is None and r["role"] == "assistant"]
|
||||
assert len(interjections) == 1
|
||||
assert len(speeches) >= 2 # initial t=0 + one paired with the interjection
|
||||
inter_t = interjections[0]["timestamp"]
|
||||
assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)
|
||||
|
||||
|
||||
def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
payload = {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 2, "note": "white & blue"},
|
||||
}
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=3),
|
||||
seed=1,
|
||||
frame_provider=_StubFrameProvider(
|
||||
cameras=("observation.images.top", "observation.images.wrist")
|
||||
),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_3")
|
||||
# every vqa row must carry a camera tag and one of the configured cameras
|
||||
for r in rows:
|
||||
assert r["style"] == "vqa"
|
||||
assert r.get("camera") in {"observation.images.top", "observation.images.wrist"}
|
||||
# at most one (vqa, user) and one (vqa, assistant) per (timestamp, camera)
|
||||
user_keys = [
|
||||
(r["timestamp"], r["camera"]) for r in rows if r["role"] == "user" and r["style"] == "vqa"
|
||||
]
|
||||
assistant_keys = [
|
||||
(r["timestamp"], r["camera"])
|
||||
for r in rows
|
||||
if r["role"] == "assistant" and r["style"] == "vqa"
|
||||
]
|
||||
assert len(user_keys) == len(set(user_keys))
|
||||
assert len(assistant_keys) == len(set(assistant_keys))
|
||||
# both cameras must be represented
|
||||
assert {c for _, c in user_keys} == {"observation.images.top", "observation.images.wrist"}
|
||||
# every emitted timestamp must be an exact source frame timestamp
|
||||
frame_set = set(record.frame_timestamps)
|
||||
for ts, _ in user_keys + assistant_keys:
|
||||
assert ts in frame_set
|
||||
|
||||
|
||||
def test_module1_attaches_video_block_to_subtask_prompt(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Module 1 sends one ``type=video`` block covering the whole episode."""
|
||||
captured: list[list[dict[str, Any]]] = []
|
||||
payload = {
|
||||
"subtasks": [
|
||||
{"text": "grasp the handle of the sponge", "start": 0.0, "end": 0.5},
|
||||
{"text": "wipe the counter", "start": 0.5, "end": 1.1},
|
||||
]
|
||||
}
|
||||
plan_payload = {"plan": "1. grasp\n2. wipe"}
|
||||
memory_payload = {"memory": "wiped once"}
|
||||
|
||||
def responder(messages):
|
||||
captured.append(list(messages))
|
||||
text = ""
|
||||
for m in messages:
|
||||
for block in m.get("content", []):
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
if "concise hierarchical PLAN" in text:
|
||||
return plan_payload
|
||||
if "Update the memory" in text:
|
||||
return memory_payload
|
||||
return payload
|
||||
|
||||
provider = _StubFrameProvider()
|
||||
module = PlanSubtasksMemoryModule(
|
||||
vlm=StubVlmClient(responder=responder),
|
||||
config=Module1Config(max_video_frames=5, frames_per_second=10.0),
|
||||
frame_provider=provider,
|
||||
)
|
||||
record = next(iter_episodes(fixture_dataset_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
|
||||
# the subtask call (the first VLM call) must carry exactly one video block
|
||||
assert captured, "no VLM calls made"
|
||||
first_call = captured[0]
|
||||
content = first_call[0]["content"]
|
||||
video_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "video"]
|
||||
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
|
||||
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
assert len(video_blocks) == 1, f"expected exactly 1 video block, got {content}"
|
||||
assert image_blocks == [], "subtask prompt must not mix image blocks with the video block"
|
||||
assert len(text_blocks) == 1
|
||||
# video block must wrap a list of frames covering the episode
|
||||
assert isinstance(video_blocks[0]["video"], list)
|
||||
assert len(video_blocks[0]["video"]) <= 5
|
||||
# provider is called with target_count = min(duration * fps, max). With
|
||||
# fps=10 on a ~1s episode the requested count exceeds max, so max=5 wins.
|
||||
assert provider.video_calls and provider.video_calls[0][0] == record.episode_index
|
||||
assert provider.video_calls[0][1] <= 5
|
||||
|
||||
|
||||
def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
"""Each VQA prompt must carry a single image block at the emission frame."""
|
||||
captured: list[list[dict[str, Any]]] = []
|
||||
payload = {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 1},
|
||||
}
|
||||
provider = _StubFrameProvider()
|
||||
module = GeneralVqaModule(
|
||||
vlm=_spy_responder(captured, payload),
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=1),
|
||||
seed=0,
|
||||
frame_provider=provider,
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
|
||||
assert captured, "no VLM calls made"
|
||||
for messages in captured:
|
||||
content = messages[0]["content"]
|
||||
image_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "image"]
|
||||
text_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
|
||||
assert image_blocks[0]["image"] is provider.sentinel
|
||||
assert len(text_blocks) == 1
|
||||
# provider was called once per emission per camera with the exact emission timestamp
|
||||
for ep_idx, ts_tuple, camera in provider.calls:
|
||||
assert ep_idx == record.episode_index
|
||||
assert len(ts_tuple) == 1
|
||||
assert ts_tuple[0] in record.frame_timestamps
|
||||
assert camera in provider.cameras
|
||||
|
||||
|
||||
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
payload = {
|
||||
"question": "Where is the cup?",
|
||||
"answer": {"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [10, 20, 50, 80]}]},
|
||||
}
|
||||
vlm = make_canned_responder({"frame-grounded visual question": payload})
|
||||
module = GeneralVqaModule(
|
||||
vlm=vlm,
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=2),
|
||||
seed=2,
|
||||
frame_provider=_StubFrameProvider(),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_3")
|
||||
for row in rows:
|
||||
if row["role"] == "assistant" and row["style"] == "vqa":
|
||||
decoded = json.loads(row["content"])
|
||||
assert "detections" in decoded
|
||||
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end smoke: pipeline output → PR 1 canonical recipe rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.config import (
|
||||
AnnotationPipelineConfig,
|
||||
Module1Config,
|
||||
Module2Config,
|
||||
Module3Config,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.executor import Executor
|
||||
from lerobot.annotations.steerable_pipeline.modules import (
|
||||
GeneralVqaModule,
|
||||
InterjectionsAndSpeechModule,
|
||||
PlanSubtasksMemoryModule,
|
||||
)
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.writer import LanguageColumnsWriter
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
_RECIPE_PATH = (
|
||||
Path(__file__).resolve().parents[2] / "src" / "lerobot" / "configs" / "recipes" / "pi05_hirobot.yaml"
|
||||
)
|
||||
|
||||
|
||||
def _build_executor() -> Executor:
|
||||
vlm = make_canned_responder(
|
||||
{
|
||||
"atomic subtasks": {
|
||||
"subtasks": [
|
||||
{"text": "grasp the bottle", "start": 0.0, "end": 0.5},
|
||||
{"text": "pour into the cup", "start": 0.5, "end": 1.0},
|
||||
{"text": "place the bottle down", "start": 1.0, "end": 1.5},
|
||||
]
|
||||
},
|
||||
"concise hierarchical PLAN": {"plan": "1. grasp\n2. pour\n3. place"},
|
||||
"Update the memory": {"memory": "poured once"},
|
||||
"acknowledgement the robot": {"text": "Sure."},
|
||||
"ONE realistic interruption": {
|
||||
"interjection": "use less water",
|
||||
"speech": "Using less water.",
|
||||
},
|
||||
"frame-grounded visual question": {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 1},
|
||||
},
|
||||
},
|
||||
)
|
||||
config = AnnotationPipelineConfig(
|
||||
module_1=Module1Config(),
|
||||
module_2=Module2Config(max_interjections_per_episode=1, interjection_min_t=0.5),
|
||||
module_3=Module3Config(vqa_emission_hz=1.0, K=2),
|
||||
)
|
||||
return Executor(
|
||||
config=config,
|
||||
module_1=PlanSubtasksMemoryModule(vlm=vlm, config=config.module_1),
|
||||
module_2=InterjectionsAndSpeechModule(vlm=vlm, config=config.module_2, seed=config.seed),
|
||||
module_3=GeneralVqaModule(vlm=vlm, config=config.module_3, seed=config.seed),
|
||||
writer=LanguageColumnsWriter(),
|
||||
validator=StagingValidator(),
|
||||
)
|
||||
|
||||
|
||||
def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
|
||||
single_episode_root: Path,
|
||||
) -> None:
|
||||
executor = _build_executor()
|
||||
summary = executor.run(single_episode_root)
|
||||
# validator may emit warnings but no errors for the synthetic fixture
|
||||
assert summary.validation_report.ok, summary.validation_report.summary()
|
||||
|
||||
table = pq.read_table(single_episode_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
persistent_lists = table.column("language_persistent").to_pylist()
|
||||
events_lists = table.column("language_events").to_pylist()
|
||||
timestamps = table.column("timestamp").to_pylist()
|
||||
|
||||
recipe = TrainingRecipe.from_yaml(_RECIPE_PATH) if hasattr(TrainingRecipe, "from_yaml") else None
|
||||
if recipe is None:
|
||||
# PR 1 may not expose from_yaml; load via PyYAML and TrainingRecipe(**...)
|
||||
import yaml
|
||||
|
||||
loaded = yaml.safe_load(_RECIPE_PATH.read_text(encoding="utf-8"))
|
||||
recipe = TrainingRecipe(**loaded)
|
||||
|
||||
rendered_any = False
|
||||
for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=float(ts),
|
||||
sample_idx=0,
|
||||
dataset_ctx={"task": "Pour water from the bottle into the cup."},
|
||||
)
|
||||
if result is None:
|
||||
continue
|
||||
if result["messages"]:
|
||||
rendered_any = True
|
||||
assert result["target_message_indices"]
|
||||
break
|
||||
assert rendered_any, "PR 1 recipe rendered no messages from pipeline output"
|
||||
|
||||
# Sanity: speech atom appears in events column intact
|
||||
flat_events = [r for ev in events_lists for r in ev]
|
||||
speech_rows = [r for r in flat_events if r.get("style") is None and r.get("role") == "assistant"]
|
||||
assert speech_rows
|
||||
say = speech_rows[0]["tool_calls"][0]
|
||||
assert say["function"]["name"] == "say"
|
||||
assert isinstance(say["function"]["arguments"]["text"], str)
|
||||
# PR 2 no longer writes a ``tools`` column — the say schema lives as a
|
||||
# constant (``SAY_TOOL_SCHEMA``) so PR 1's row struct is the single
|
||||
# source of truth for the v3.1 schema.
|
||||
assert "tools" not in table.column_names
|
||||
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Validator behavior tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.validator import StagingValidator
|
||||
from lerobot.annotations.steerable_pipeline.writer import speech_atom
|
||||
|
||||
|
||||
def _validate(root: Path, staging_dir: Path):
|
||||
records = list(iter_episodes(root))
|
||||
return StagingValidator().validate(records, staging_dir)
|
||||
|
||||
|
||||
def test_validator_catches_misaligned_timestamps(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_3",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps({"label": "cup", "count": 2}, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": 9.999, # not on any 10 fps frame
|
||||
"tool_calls": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("does not match any source frame timestamp" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_orphan_speech(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_2",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
# interjection at 0.3s with NO paired speech
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip it",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("paired speech" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_inconsistent_plan_memory(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_1",
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do x",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do x",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_2",
|
||||
[
|
||||
speech_atom(0.0, "Got it."),
|
||||
speech_atom(0.4, "Replanning."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "replan",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.4,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
# missing co-timestamped plan refresh at 0.4s → error
|
||||
assert not report.ok
|
||||
assert any("co-timestamped plan update" in e for e in report.errors)
|
||||
|
||||
|
||||
def test_validator_catches_wrong_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
EpisodeStaging(staging_dir, 0).write(
|
||||
"module_1",
|
||||
[
|
||||
{"role": "user", "content": "where?", "style": "vqa", "timestamp": 0.0, "tool_calls": None},
|
||||
],
|
||||
)
|
||||
report = _validate(fixture_dataset_root, staging_dir)
|
||||
assert not report.ok
|
||||
assert any("module_1 emitted style 'vqa'" in e or "must be persistent" in e for e in report.errors)
|
||||
@@ -0,0 +1,298 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Writer correctness tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes
|
||||
from lerobot.annotations.steerable_pipeline.staging import EpisodeStaging
|
||||
from lerobot.annotations.steerable_pipeline.writer import (
|
||||
LanguageColumnsWriter,
|
||||
speech_atom,
|
||||
)
|
||||
|
||||
|
||||
def _stage_episode(
|
||||
staging_dir: Path,
|
||||
episode_index: int,
|
||||
*,
|
||||
module_1: list[dict] | None = None,
|
||||
module_2: list[dict] | None = None,
|
||||
module_3: list[dict] | None = None,
|
||||
) -> None:
|
||||
staging = EpisodeStaging(staging_dir, episode_index)
|
||||
if module_1 is not None:
|
||||
staging.write("module_1", module_1)
|
||||
if module_2 is not None:
|
||||
staging.write("module_2", module_2)
|
||||
if module_3 is not None:
|
||||
staging.write("module_3", module_3)
|
||||
|
||||
|
||||
def test_writer_persistence_identity(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Every frame in an episode has a byte-identical persistent list."""
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "grasp the sponge",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. wipe\n2. dry",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "wiped the counter",
|
||||
"style": "memory",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
persistent = table.column("language_persistent").to_pylist()
|
||||
first = persistent[0]
|
||||
assert first # non-empty
|
||||
for row in persistent:
|
||||
assert row == first, "persistent slice must be byte-identical across all frames"
|
||||
|
||||
|
||||
def test_writer_events_exact_timestamp(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_2=[
|
||||
speech_atom(0.0, "Got it."),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "skip the dishes",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.5,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.5, "Skipping the dishes."),
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
timestamps = table.column("timestamp").to_pylist()
|
||||
events = table.column("language_events").to_pylist()
|
||||
for ts, ev in zip(timestamps, events, strict=True):
|
||||
if abs(ts - 0.0) < 1e-9:
|
||||
assert any(r["role"] == "assistant" and r.get("style") is None for r in ev), ev
|
||||
elif abs(ts - 0.5) < 1e-9:
|
||||
assert any(r.get("style") == "interjection" for r in ev), ev
|
||||
assert any(r.get("style") is None for r in ev), ev
|
||||
else:
|
||||
assert ev == []
|
||||
|
||||
|
||||
def test_writer_column_routing(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "1. do X",
|
||||
"style": "plan",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "did X",
|
||||
"style": "memory",
|
||||
"timestamp": 0.3,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
module_2=[
|
||||
speech_atom(0.0, "OK"),
|
||||
{
|
||||
"role": "user",
|
||||
"content": "wait",
|
||||
"style": "interjection",
|
||||
"timestamp": 0.2,
|
||||
"tool_calls": None,
|
||||
},
|
||||
speech_atom(0.2, "Waiting"),
|
||||
],
|
||||
module_3=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "where is the cup?",
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"tool_calls": None,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": json.dumps(
|
||||
{"detections": [{"label": "cup", "bbox_format": "xyxy", "bbox": [1, 2, 3, 4]}]},
|
||||
sort_keys=True,
|
||||
),
|
||||
"style": "vqa",
|
||||
"timestamp": 0.4,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
|
||||
persistent = table.column("language_persistent").to_pylist()[0]
|
||||
persistent_styles = {r["style"] for r in persistent}
|
||||
assert persistent_styles == {"subtask", "plan", "memory"}
|
||||
|
||||
all_events = [r for ev in table.column("language_events").to_pylist() for r in ev]
|
||||
event_styles = {r.get("style") for r in all_events}
|
||||
assert event_styles == {None, "interjection", "vqa"}
|
||||
|
||||
|
||||
def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "do X",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
writer = LanguageColumnsWriter()
|
||||
writer.write_all(records, staging_dir, fixture_dataset_root)
|
||||
|
||||
path = fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet"
|
||||
table_a = pq.read_table(path)
|
||||
assert "subtask_index" not in table_a.column_names
|
||||
assert "language_persistent" in table_a.column_names
|
||||
assert "language_events" in table_a.column_names
|
||||
# The writer no longer emits a dataset-level ``tools`` column; the
|
||||
# ``say`` tool schema lives as a code constant (``SAY_TOOL_SCHEMA``)
|
||||
# so the parquet stays small and PR 2 doesn't extend PR 1's schema.
|
||||
assert "tools" not in table_a.column_names
|
||||
|
||||
# second pass — must produce identical bytes for the language columns
|
||||
records_again = list(iter_episodes(fixture_dataset_root))
|
||||
writer.write_all(records_again, staging_dir, fixture_dataset_root)
|
||||
table_b = pq.read_table(path)
|
||||
assert (
|
||||
table_a.column("language_persistent").to_pylist() == table_b.column("language_persistent").to_pylist()
|
||||
)
|
||||
assert table_a.column("language_events").to_pylist() == table_b.column("language_events").to_pylist()
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_persistent_style() -> None:
|
||||
"""``_normalize_persistent_row`` must reject any non-persistent style."""
|
||||
from lerobot.annotations.steerable_pipeline.writer import _normalize_persistent_row
|
||||
|
||||
with pytest.raises(ValueError, match="non-persistent style"):
|
||||
_normalize_persistent_row(
|
||||
{"role": "assistant", "content": "oops", "style": "vqa", "timestamp": 0.0, "tool_calls": None}
|
||||
)
|
||||
|
||||
|
||||
def test_writer_normalize_rejects_misrouted_event_style() -> None:
|
||||
"""``_normalize_event_row`` must reject any persistent style."""
|
||||
from lerobot.annotations.steerable_pipeline.writer import _normalize_event_row
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None})
|
||||
|
||||
|
||||
def test_say_tool_schema_constant_is_well_formed() -> None:
|
||||
"""``SAY_TOOL_SCHEMA`` (and ``DEFAULT_TOOLS``) replace the parquet
|
||||
``tools`` column — chat-template consumers import them directly.
|
||||
"""
|
||||
from lerobot.annotations.steerable_pipeline.writer import (
|
||||
DEFAULT_TOOLS,
|
||||
SAY_TOOL_SCHEMA,
|
||||
)
|
||||
|
||||
assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA]
|
||||
assert SAY_TOOL_SCHEMA["function"]["name"] == "say"
|
||||
params = SAY_TOOL_SCHEMA["function"]["parameters"]
|
||||
assert params["properties"]["text"]["type"] == "string"
|
||||
assert params["required"] == ["text"]
|
||||
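# Illustrative sketch (not part of the test suite): the assertions in
# ``test_say_tool_schema_constant_is_well_formed`` pin down the shape of the
# constant without showing it. A value consistent with those assertions could
# look like the hypothetical ``_EXAMPLE_SAY_TOOL_SCHEMA`` below; the real
# ``SAY_TOOL_SCHEMA`` lives in ``lerobot.annotations.steerable_pipeline.writer``
# and may carry extra fields (e.g. a description), so treat this as an
# assumption rather than the actual definition.
_EXAMPLE_SAY_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "say",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string"}},
            "required": ["text"],
        },
    },
}
_EXAMPLE_DEFAULT_TOOLS = [_EXAMPLE_SAY_TOOL_SCHEMA]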
|
||||
|
||||
def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
|
||||
"""Re-running on a parquet that already has a legacy ``tools`` column
|
||||
must drop it cleanly so reruns converge to the v3.1 schema.
|
||||
"""
|
||||
staging_dir = tmp_path / "stage"
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
0,
|
||||
module_1=[
|
||||
{"role": "assistant", "content": "x", "style": "subtask", "timestamp": 0.0, "tool_calls": None}
|
||||
],
|
||||
)
|
||||
records = list(iter_episodes(fixture_dataset_root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
|
||||
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
|
||||
assert "tools" not in table.column_names
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||
assert atom["role"] == "assistant"
|
||||
assert atom["style"] is None
|
||||
assert atom["content"] is None
|
||||
assert atom["timestamp"] == 2.5
|
||||
assert isinstance(atom["tool_calls"], list)
|
||||
call = atom["tool_calls"][0]
|
||||
assert call["type"] == "function"
|
||||
assert call["function"]["name"] == "say"
|
||||
assert call["function"]["arguments"]["text"] == "I'm cleaning up!"
|
||||
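# Illustrative sketch (not part of the test suite): a helper with the row
# shape that ``test_speech_atom_shape_matches_plan_spec`` asserts. The real
# ``speech_atom`` is imported from
# ``lerobot.annotations.steerable_pipeline.writer``; this hypothetical
# stand-in only restates the layout implied by the assertions.
def _example_speech_atom(timestamp: float, text: str) -> dict:
    """Build an assistant row whose only payload is a `say` tool call."""
    return {
        "role": "assistant",
        "content": None,
        "style": None,
        "timestamp": timestamp,
        "tool_calls": [
            {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
        ],
    }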
@@ -0,0 +1,32 @@
#!/usr/bin/env python

from pathlib import Path

import pytest

from lerobot.configs.recipe import MessageTurn, TrainingRecipe


def test_message_recipe_validates_unknown_binding():
    with pytest.raises(ValueError, match="unknown binding"):
        TrainingRecipe(
            messages=[
                MessageTurn(role="user", content="${missing}", stream="high_level"),
                MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
            ]
        )


def test_canonical_recipe_loads():
    recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))

    assert recipe.blend is not None
    assert set(recipe.blend) == {
        "memory_update",
        "user_interjection_response",
        "high_level_subtask",
        "low_level_execution",
        "ask_vqa_top",
        "ask_vqa_wrist",
    }
    assert sum(component.weight for component in recipe.blend.values()) == pytest.approx(0.96)
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.io_utils import write_info
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
VIEW_DEPENDENT_STYLES,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
language_events_arrow_type,
|
||||
language_feature_info,
|
||||
language_persistent_arrow_type,
|
||||
validate_camera_field,
|
||||
)
|
||||
from lerobot.datasets.utils import DEFAULT_DATA_PATH
|
||||
|
||||
|
||||
def test_language_arrow_schema_has_expected_fields():
|
||||
persistent_row_type = language_persistent_arrow_type().value_type
|
||||
event_row_type = language_events_arrow_type().value_type
|
||||
|
||||
assert isinstance(persistent_row_type, pa.StructType)
|
||||
assert persistent_row_type.names == [
|
||||
"role",
|
||||
"content",
|
||||
"style",
|
||||
"timestamp",
|
||||
"camera",
|
||||
"tool_calls",
|
||||
]
|
||||
|
||||
assert isinstance(event_row_type, pa.StructType)
|
||||
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
|
||||
|
||||
|
||||
def test_style_registry_routes_columns():
|
||||
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
|
||||
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
|
||||
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
|
||||
|
||||
assert column_for_style("subtask") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("plan") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("memory") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("motion") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
|
||||
assert column_for_style("interjection") == LANGUAGE_EVENTS
|
||||
assert column_for_style("vqa") == LANGUAGE_EVENTS
|
||||
assert column_for_style("trace") == LANGUAGE_EVENTS
|
||||
assert column_for_style(None) == LANGUAGE_EVENTS
|
||||
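# Illustrative sketch (not part of the test suite): the routing rule the
# assertions above exercise. The real ``column_for_style`` lives in
# ``lerobot.datasets.language``; this hypothetical version only restates it:
# persistent styles go to ``language_persistent``, event-only styles (and the
# style=None speech atoms) go to ``language_events``, anything else is rejected.
def _example_column_for_style(style: str | None) -> str:
    if style in PERSISTENT_STYLES:
        return LANGUAGE_PERSISTENT
    if style is None or style in EVENT_ONLY_STYLES:
        return LANGUAGE_EVENTS
    raise ValueError(f"Unknown language style: {style!r}")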
|
||||
|
||||
def test_view_dependent_styles():
|
||||
# motion lives in PERSISTENT_STYLES and is described in robot-frame
|
||||
# (joint / Cartesian) terms, so it is NOT view-dependent. Only vqa
|
||||
# (event) and trace (event, pixel-trajectory) carry a camera tag.
|
||||
assert {"vqa", "trace"} == VIEW_DEPENDENT_STYLES
|
||||
assert is_view_dependent_style("vqa")
|
||||
assert is_view_dependent_style("trace")
|
||||
assert not is_view_dependent_style("motion")
|
||||
assert not is_view_dependent_style("subtask")
|
||||
assert not is_view_dependent_style("plan")
|
||||
assert not is_view_dependent_style("interjection")
|
||||
assert not is_view_dependent_style(None)
|
||||
|
||||
|
||||
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
|
||||
validate_camera_field("vqa", "observation.images.top")
|
||||
validate_camera_field("trace", "observation.images.front")
|
||||
with pytest.raises(ValueError, match="view-dependent"):
|
||||
validate_camera_field("vqa", None)
|
||||
with pytest.raises(ValueError, match="view-dependent"):
|
||||
validate_camera_field("trace", "")
|
||||
|
||||
|
||||
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
|
||||
validate_camera_field("subtask", None)
|
||||
validate_camera_field("plan", None)
|
||||
validate_camera_field("memory", None)
|
||||
validate_camera_field("motion", None)
|
||||
validate_camera_field("interjection", None)
|
||||
validate_camera_field(None, None)
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field("subtask", "observation.images.top")
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field("motion", "observation.images.top")
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field("interjection", "observation.images.top")
|
||||
with pytest.raises(ValueError, match="must have camera=None"):
|
||||
validate_camera_field(None, "observation.images.top")
|
||||
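# Illustrative sketch (not part of the test suite): the camera-field contract
# the two tests above check. The real ``validate_camera_field`` lives in
# ``lerobot.datasets.language``; this hypothetical version restates the rule:
# view-dependent styles (vqa, trace) require a non-empty camera key, every
# other style (and style=None) must leave camera unset.
def _example_validate_camera_field(style: str | None, camera: str | None) -> None:
    if is_view_dependent_style(style):
        if not camera:
            raise ValueError(f"style {style!r} is view-dependent and requires a camera")
    elif camera:
        raise ValueError(f"style {style!r} must have camera=None, got {camera!r}")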
|
||||
|
||||
def test_unknown_style_rejected():
|
||||
with pytest.raises(ValueError, match="Unknown language style"):
|
||||
column_for_style("surprise")
|
||||
|
||||
|
||||
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
|
||||
root = tmp_path / "language_dataset"
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=root,
|
||||
features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
|
||||
use_videos=False,
|
||||
)
|
||||
dataset.add_frame({"state": np.array([0.0, 1.0], dtype=np.float32), "task": "tidy"})
|
||||
dataset.add_frame({"state": np.array([1.0, 2.0], dtype=np.float32), "task": "tidy"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
persistent = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "reach for the cup",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"camera": None,
|
||||
"tool_calls": None,
|
||||
}
|
||||
]
|
||||
event = {
|
||||
"role": "user",
|
||||
"content": "what is visible?",
|
||||
"style": "vqa",
|
||||
"camera": "observation.images.top",
|
||||
"tool_calls": None,
|
||||
}
|
||||
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
df = pd.read_parquet(data_path)
|
||||
df[LANGUAGE_PERSISTENT] = [persistent, persistent]
|
||||
df[LANGUAGE_EVENTS] = [[event], []]
|
||||
df.to_parquet(data_path)
|
||||
|
||||
info = dataset.meta.info
|
||||
info["features"].update(language_feature_info())
|
||||
write_info(info, root)
|
||||
|
||||
reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
|
||||
|
||||
first = reloaded[0]
|
||||
second = reloaded[1]
|
||||
assert first[LANGUAGE_PERSISTENT] == persistent
|
||||
assert first[LANGUAGE_EVENTS] == [event]
|
||||
assert second[LANGUAGE_PERSISTENT] == persistent
|
||||
assert second[LANGUAGE_EVENTS] == []
|
||||
@@ -0,0 +1,388 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
|
||||
|
||||
|
||||
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"style": style,
|
||||
"timestamp": timestamp,
|
||||
"camera": camera,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
|
||||
def event_row(role, content, style, tool_calls=None, camera=None):
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"style": style,
|
||||
"camera": camera,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
|
||||
PERSISTENT = [
|
||||
persistent_row("assistant", "plan 0", "plan", 0.0),
|
||||
persistent_row("assistant", "memory 0", "memory", 0.0),
|
||||
persistent_row("assistant", "subtask 0", "subtask", 0.0),
|
||||
persistent_row("assistant", "memory 1", "memory", 1.0),
|
||||
persistent_row("assistant", "subtask 1", "subtask", 1.0),
|
||||
]
|
||||
EVENTS_AT_1 = [
|
||||
event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
|
||||
event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
|
||||
]
|
||||
EVENTS_AT_2 = [
|
||||
event_row("user", "skip wiping", "interjection"),
|
||||
event_row(
|
||||
"assistant",
|
||||
None,
|
||||
None,
|
||||
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
|
||||
),
|
||||
]
|
||||
# Same emission tick, two cameras: triggers per-camera disambiguation in
|
||||
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
|
||||
# (vqa, user) + (vqa, assistant) pair per camera.
|
||||
EVENTS_AT_3_TWO_CAMERAS = [
|
||||
event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
|
||||
event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
|
||||
event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
|
||||
event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
|
||||
]
|
||||
|
||||
|
||||
def test_resolver_temporal_semantics():
|
||||
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
|
||||
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
|
||||
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
|
||||
assert (
|
||||
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
|
||||
== '{"count": 2}'
|
||||
)
|
||||
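# Illustrative sketch (not part of the test suite): the temporal rule that
# ``test_resolver_temporal_semantics`` relies on. The real ``active_at`` lives
# in ``lerobot.datasets.language_render``; this hypothetical version restates
# it: the row of the requested persistent style with the largest timestamp
# <= t is "active", so at t == 1.0 the refreshed subtask wins. ``emitted_at``
# is the event-side counterpart: it filters the frame's event rows by style,
# role (and camera, see the two-camera tests below) and returns None when
# nothing was emitted at that tick.
def _example_active_at(t: float, persistent: list[dict], style: str) -> dict | None:
    candidates = [row for row in persistent if row["style"] == style and row["timestamp"] <= t]
    return max(candidates, key=lambda row: row["timestamp"]) if candidates else None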
|
||||
|
||||
def test_persistent_relative_resolvers_reject_event_styles():
|
||||
with pytest.raises(ValueError, match="event-only"):
|
||||
active_at(1.0, persistent=PERSISTENT, style="vqa")
|
||||
with pytest.raises(ValueError, match="event-only"):
|
||||
nth_prev(1.0, persistent=PERSISTENT, style="interjection")
|
||||
|
||||
|
||||
def test_nth_prev_and_next():
|
||||
assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
|
||||
assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"
|
||||
|
||||
|
||||
def test_substitution_if_present_multimodal_and_tool_calls():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.top"},
|
||||
{"type": "text", "text": "${task}: ${interjection}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="interjection",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${plan}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
tool_calls_from="speech",
|
||||
),
|
||||
],
|
||||
bindings={"plan": "active_at(t, style=plan)"},
|
||||
)
|
||||
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_2,
|
||||
t=2.0,
|
||||
sample_idx=0,
|
||||
task="clean kitchen",
|
||||
)
|
||||
|
||||
assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
|
||||
assert rendered["messages"][1]["content"] == "plan 0"
|
||||
assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
|
||||
assert rendered["message_streams"] == ["high_level", "high_level"]
|
||||
assert rendered["target_message_indices"] == [1]
|
||||
|
||||
|
||||
def test_exact_event_miss_returns_none_when_target_skips():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert (
|
||||
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
|
||||
)
|
||||
|
||||
|
||||
def test_deterministic_blend_sampling():
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"a": TrainingRecipe(
|
||||
weight=1.0,
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="a", stream="high_level", target=True),
|
||||
],
|
||||
),
|
||||
"b": TrainingRecipe(
|
||||
weight=1.0,
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="b", stream="high_level", target=True),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
first = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
|
||||
)
|
||||
second = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
|
||||
)
|
||||
assert first == second
|
||||
|
||||
|
||||
def test_emitted_at_filters_vqa_by_camera():
|
||||
top = emitted_at(
|
||||
3.0,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
style="vqa",
|
||||
role="assistant",
|
||||
camera="observation.images.top",
|
||||
)
|
||||
wrist = emitted_at(
|
||||
3.0,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
style="vqa",
|
||||
role="assistant",
|
||||
camera="observation.images.wrist",
|
||||
)
|
||||
assert top["content"] == '{"count": 3}'
|
||||
assert wrist["content"] == '{"count": 1}'
|
||||
|
||||
|
||||
def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
|
||||
with pytest.raises(ValueError, match="Ambiguous resolver"):
|
||||
emitted_at(
|
||||
3.0,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
style="vqa",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
def test_per_camera_blend_renders_both_views():
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"top": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": (
|
||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
|
||||
),
|
||||
"vqa": (
|
||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
|
||||
),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.top"},
|
||||
{"type": "text", "text": "${vqa_query}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
),
|
||||
"wrist": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": (
|
||||
"emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
),
|
||||
"vqa": (
|
||||
"emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.wrist"},
|
||||
{"type": "text", "text": "${vqa_query}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
rendered_top = render_sample(
|
||||
recipe=recipe.blend["top"],
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
t=3.0,
|
||||
sample_idx=0,
|
||||
)
|
||||
rendered_wrist = render_sample(
|
||||
recipe=recipe.blend["wrist"],
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
t=3.0,
|
||||
sample_idx=0,
|
||||
)
|
||||
|
||||
assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top"
|
||||
assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?"
|
||||
assert rendered_top["messages"][1]["content"] == '{"count": 3}'
|
||||
|
||||
assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist"
|
||||
assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?"
|
||||
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
||||
|
||||
|
||||
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||
rephrasings = [
|
||||
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
|
||||
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
|
||||
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
|
||||
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
|
||||
]
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
|
||||
# No explicit task override → resolver consults persistent rows.
|
||||
seen: set[str] = set()
|
||||
for sample_idx in range(64):
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=sample_idx,
|
||||
dataset_ctx={"task": "canonical kitchen task"},
|
||||
)
|
||||
seen.add(rendered["messages"][0]["content"])
|
||||
# Every rephrasing should be reachable across enough samples.
|
||||
assert seen == {r["content"] for r in rephrasings}
|
||||
# Same sample_idx → same pick (determinism).
|
||||
a = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
b = render_sample(
|
||||
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||
|
||||
|
||||
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=PERSISTENT, # no task_aug rows
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
dataset_ctx={"task": "clean the kitchen"},
|
||||
)
|
||||
assert rendered["messages"][0]["content"] == "clean the kitchen"
|
||||
|
||||
|
||||
def test_resolve_task_explicit_override_beats_rephrasings():
|
||||
rephrasings = [
|
||||
persistent_row("user", "rephrased one", "task_aug", 0.0),
|
||||
persistent_row("user", "rephrased two", "task_aug", 0.0),
|
||||
]
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=rephrasings,
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
task="explicit override wins",
|
||||
dataset_ctx={"task": "canonical"},
|
||||
)
|
||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||
|
||||
|
||||
def test_canonical_recipe_can_render_low_level_branch():
|
||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||
|
||||
rendered = render_sample(
|
||||
recipe=low_level,
|
||||
persistent=PERSISTENT,
|
||||
events=[],
|
||||
t=0.5,
|
||||
sample_idx=0,
|
||||
task="clean kitchen",
|
||||
)
|
||||
|
||||
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
|
||||
assert rendered["message_streams"][-1] == "low_level"
|
||||
assert rendered["target_message_indices"] == [1]
|
||||
@@ -1,193 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for subtask functionality in LeRobotDataset.
|
||||
|
||||
These tests verify that:
|
||||
- Subtask information is correctly loaded from datasets that have subtask data
|
||||
- The __getitem__ method correctly adds subtask strings to returned items
|
||||
- Subtask handling gracefully handles missing data
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
class TestSubtaskDataset:
|
||||
"""Tests for subtask handling in LeRobotDataset."""
|
||||
|
||||
@pytest.fixture
|
||||
def subtask_dataset(self):
|
||||
"""Load the test subtask dataset from the hub."""
|
||||
# Use lerobot/pusht-subtask dataset with episode 1
|
||||
return LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
|
||||
def test_subtask_dataset_loads(self, subtask_dataset):
|
||||
"""Test that the subtask dataset loads successfully."""
|
||||
assert subtask_dataset is not None
|
||||
assert len(subtask_dataset) > 0
|
||||
|
||||
def test_subtask_metadata_loaded(self, subtask_dataset):
|
||||
"""Test that subtask metadata is loaded when present in dataset."""
|
||||
# The dataset should have subtasks metadata loaded
|
||||
assert subtask_dataset.meta.subtasks is not None
|
||||
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
|
||||
|
||||
def test_subtask_index_in_features(self, subtask_dataset):
|
||||
"""Test that subtask_index is a feature when dataset has subtasks."""
|
||||
assert "subtask_index" in subtask_dataset.features
|
||||
|
||||
def test_getitem_returns_subtask_string(self, subtask_dataset):
|
||||
"""Test that __getitem__ correctly adds subtask string to returned item."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
# Subtask should be present in the returned item
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["subtask"], str)
|
||||
assert len(item["subtask"]) > 0 # Should not be empty
|
||||
|
||||
def test_getitem_has_subtask_index(self, subtask_dataset):
|
||||
"""Test that __getitem__ includes subtask_index."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
assert "subtask_index" in item
|
||||
assert isinstance(item["subtask_index"], torch.Tensor)
|
||||
|
||||
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
|
||||
"""Test that subtask_index correctly maps to a subtask in metadata."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
assert item["subtask"] == subtask_from_metadata
|
||||
|
||||
def test_all_items_have_subtask(self, subtask_dataset):
|
||||
"""Test that all items in the dataset have subtask information."""
|
||||
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
|
||||
item = subtask_dataset[i]
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["subtask"], str)
|
||||
|
||||
def test_task_and_subtask_coexist(self, subtask_dataset):
|
||||
"""Test that both task and subtask are present in returned items."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
# Both task and subtask should be present
|
||||
assert "task" in item
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["task"], str)
|
||||
assert isinstance(item["subtask"], str)
|
||||
|
||||
|
||||
class TestSubtaskDatasetMissing:
|
||||
"""Tests for graceful handling when subtask data is missing."""
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Create a dataset without subtask information."""
|
||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
|
||||
|
||||
# Add some frames and save
|
||||
for _ in range(5):
|
||||
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Reload the dataset
|
||||
return LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
def test_no_subtask_in_features(self, dataset_without_subtasks):
|
||||
"""Test that subtask_index is not in features when not provided."""
|
||||
assert "subtask_index" not in dataset_without_subtasks.features
|
||||
|
||||
def test_getitem_without_subtask(self, dataset_without_subtasks):
|
||||
"""Test that __getitem__ works when subtask is not present."""
|
||||
item = dataset_without_subtasks[0]
|
||||
|
||||
# Item should still be retrievable
|
||||
assert item is not None
|
||||
assert "state" in item
|
||||
assert "task" in item
|
||||
|
||||
# Subtask should NOT be present
|
||||
assert "subtask" not in item
|
||||
|
||||
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
|
||||
"""Test that subtasks metadata is None when not present."""
|
||||
assert dataset_without_subtasks.meta.subtasks is None
|
||||
|
||||
|
||||
class TestSubtaskEdgeCases:
|
||||
"""Edge case tests for subtask handling."""
|
||||
|
||||
def test_subtask_with_multiple_episodes(self):
|
||||
"""Test subtask handling with multiple episodes if available."""
|
||||
try:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
except Exception:
|
||||
pytest.skip("Could not load test-subtask dataset")
|
||||
|
||||
# Check first and last items have valid subtasks
|
||||
first_item = dataset[0]
|
||||
last_item = dataset[len(dataset) - 1]
|
||||
|
||||
assert "subtask" in first_item
|
||||
assert "subtask" in last_item
|
||||
assert isinstance(first_item["subtask"], str)
|
||||
assert isinstance(last_item["subtask"], str)
|
||||
|
||||
def test_subtask_index_consistency(self):
|
||||
"""Test that same subtask_index returns same subtask string."""
|
||||
try:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
except Exception:
|
||||
pytest.skip("Could not load test-subtask dataset")
|
||||
|
||||
if len(dataset) < 2:
|
||||
pytest.skip("Dataset too small for this test")
|
||||
|
||||
# Collect subtask_index to subtask mappings
|
||||
subtask_map = {}
|
||||
for i in range(min(len(dataset), 10)):
|
||||
item = dataset[i]
|
||||
idx = item["subtask_index"].item()
|
||||
subtask = item["subtask"]
|
||||
|
||||
if idx in subtask_map:
|
||||
# Same index should always return same subtask
|
||||
assert subtask_map[idx] == subtask, (
|
||||
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
|
||||
)
|
||||
else:
|
||||
subtask_map[idx] = subtask
|
||||
@@ -0,0 +1,56 @@
#!/usr/bin/env python

import torch

from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.processor.converters import create_transition
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.types import TransitionKey


def test_render_messages_step_noops_without_language_columns():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
        ]
    )
    transition = create_transition(complementary_data={"task": "do it"})

    assert RenderMessagesStep(recipe)(transition) == transition


def test_render_messages_step_renders_and_drops_raw_language():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
        ]
    )
    transition = create_transition(
        complementary_data={
            "task": "do it",
            "timestamp": torch.tensor(0.0),
            "index": torch.tensor(7),
            "language_persistent": [
                {
                    "role": "assistant",
                    "content": "reach carefully",
                    "style": "subtask",
                    "timestamp": 0.0,
                    "camera": None,
                    "tool_calls": None,
                }
            ],
            "language_events": [],
        }
    )

    out = RenderMessagesStep(recipe)(transition)
    data = out[TransitionKey.COMPLEMENTARY_DATA]

    assert "language_persistent" not in data
    assert "language_events" not in data
    assert data["messages"][-1]["content"] == "reach carefully"
    assert data["message_streams"] == ["high_level", "low_level"]
    assert data["target_message_indices"] == [1]
@@ -0,0 +1,36 @@
#!/usr/bin/env python

import torch

from lerobot.utils.collate import lerobot_collate_fn


def test_lerobot_collate_preserves_messages_and_drops_raw_language():
    batch = [
        {
            "index": torch.tensor(0),
            "messages": [{"role": "assistant", "content": "a"}],
            "message_streams": ["low_level"],
            "target_message_indices": [0],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        },
        {
            "index": torch.tensor(1),
            "messages": [{"role": "assistant", "content": "b"}],
            "message_streams": ["low_level"],
            "target_message_indices": [0],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        },
    ]

    out = lerobot_collate_fn(batch)

    assert out["index"].tolist() == [0, 1]
    assert out["messages"][0][0]["content"] == "a"
    assert out["messages"][1][0]["content"] == "b"
    assert out["message_streams"] == [["low_level"], ["low_level"]]
    assert out["target_message_indices"] == [[0], [0]]
    assert "language_persistent" not in out
    assert "language_events" not in out
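# Illustrative sketch (not part of the test suite): the collation behaviour
# the assertions above describe. The real ``lerobot_collate_fn`` lives in
# ``lerobot.utils.collate``; this hypothetical reduction only shows the three
# rules exercised here: tensors are stacked, message-related keys are kept as
# per-sample Python lists, and the raw language columns are dropped because
# ``messages`` already carries the rendered result.
def _example_collate(batch: list[dict]) -> dict:
    dropped = {"language_persistent", "language_events"}
    out: dict = {}
    for key in batch[0]:
        if key in dropped:
            continue
        values = [sample[key] for sample in batch]
        out[key] = torch.stack(values) if isinstance(values[0], torch.Tensor) else values
    return out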