Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-11 22:59:50 +00:00
Compare commits
8 Commits
| SHA1 |
| --- |
| 73780046b2 |
| 093a85f946 |
| a669049da2 |
| ce348a3460 |
| cb920235c4 |
| 7f40b3bf82 |
| 2e9c9fd832 |
| f9cb5e659c |

@@ -31,7 +31,8 @@ jobs:
    name: Upload Preview and Comment
    if: >
      github.event.workflow_run.event == 'pull_request' &&
-     github.event.workflow_run.conclusion == 'success'
+     github.event.workflow_run.conclusion == 'success' &&
+     github.repository == 'huggingface/lerobot'
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
    with:
      package_name: lerobot

@@ -42,7 +42,9 @@ jobs:
  # This job builds and deploys the official documentation.
  build_main_docs:
    name: Build Main Docs
-   if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
+   if: >
+     (github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
+     github.repository == 'huggingface/lerobot'
    permissions:
      contents: read
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main

@@ -58,7 +60,7 @@ jobs:
  # The result of this job triggers the 'Upload PR Documentation' workflow.
  build_pr_docs:
    name: Build PR Docs
-   if: github.event_name == 'pull_request'
+   if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
    permissions:
      contents: read
      pull-requests: write

@@ -45,7 +45,6 @@ permissions:
env:
  UV_VERSION: "0.8.0"
  PYTHON_VERSION: "3.10"
- DOCKER_IMAGE_NAME: huggingface/lerobot-gpu

# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:

@@ -43,6 +43,7 @@ jobs:
    name: Build CPU Docker for Nightly
    runs-on:
      group: aws-general-8-plus
+   if: github.repository == 'huggingface/lerobot'
    outputs:
      image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
    steps:

@@ -77,6 +78,7 @@ jobs:
    name: Build GPU Docker for Nightly
    runs-on:
      group: aws-general-8-plus
+   if: github.repository == 'huggingface/lerobot'
    outputs:
      image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
    steps:

@@ -29,6 +29,7 @@ jobs:
  build-and-publish:
    name: Build and publish Python distributions
    runs-on: ubuntu-latest
+   if: github.repository == 'huggingface/lerobot'
    outputs:
      version: ${{ steps.extract_info.outputs.tag_version }}
    permissions:

@@ -45,6 +45,7 @@ jobs:
  stale:
    name: Close Stale Issues and PRs
    runs-on: ubuntu-latest
+   if: github.repository == 'huggingface/lerobot'
    permissions:
      actions: write
      contents: write # only for delete-branch option

@@ -43,6 +43,7 @@ jobs:
  full-tests:
    name: Full Unbound Tests
    runs-on: ubuntu-latest
+   if: github.repository == 'huggingface/lerobot'
    env:
      MUJOCO_GL: egl
      HF_HOME: /mnt/cache/.cache/huggingface

@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:

3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
+6. **Convert to Video** - Convert image-based datasets to video format for efficient storage

The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
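
The same operations are also available programmatically. A minimal sketch (the function name `delete_episodes` and its signature are assumptions for illustration, not confirmed here; see `examples/dataset/use_dataset_tools.py` for the authoritative API):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import dataset_tools

ds = LeRobotDataset("lerobot/pusht_image")
# Drop a few episodes and get back a new dataset (hypothetical call shape).
trimmed = dataset_tools.delete_episodes(ds, episode_indices=[0, 2, 5])
```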

## Command-Line Tool: lerobot-edit-dataset

-`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
+`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.

Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.

@@ -86,9 +87,71 @@ lerobot-edit-dataset \
    --operation.feature_names "['observation.images.top']"
```

#### Convert to Video

Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data-loading performance. The new dataset has exactly the same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.

```bash
# Local-only: Save to a custom output directory (no hub push)
lerobot-edit-dataset \
    --repo_id lerobot/pusht_image \
    --operation.type convert_to_video \
    --operation.output_dir /path/to/output/pusht_video

# Save with new repo_id (local storage)
lerobot-edit-dataset \
    --repo_id lerobot/pusht_image \
    --new_repo_id lerobot/pusht_video \
    --operation.type convert_to_video

# Convert and push to Hugging Face Hub
lerobot-edit-dataset \
    --repo_id lerobot/pusht_image \
    --new_repo_id lerobot/pusht_video \
    --operation.type convert_to_video \
    --push_to_hub true

# Convert with custom video codec and quality settings
lerobot-edit-dataset \
    --repo_id lerobot/pusht_image \
    --operation.type convert_to_video \
    --operation.output_dir outputs/pusht_video \
    --operation.vcodec libsvtav1 \
    --operation.pix_fmt yuv420p \
    --operation.g 2 \
    --operation.crf 30

# Convert only specific episodes
lerobot-edit-dataset \
    --repo_id lerobot/pusht_image \
    --operation.type convert_to_video \
    --operation.output_dir outputs/pusht_video \
    --operation.episode_indices "[0, 1, 2, 5, 10]"

# Convert with multiple workers for parallel processing
lerobot-edit-dataset \
    --repo_id lerobot/pusht_image \
    --operation.type convert_to_video \
    --operation.output_dir outputs/pusht_video \
    --operation.num_workers 8
```

**Parameters:**

- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)

**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
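
To sanity-check a conversion, the result can be loaded back with the standard dataset class. A minimal sketch, assuming a converted dataset was written locally under `outputs/pusht_video` with repo id `lerobot/pusht_video`; attribute names such as `meta.video_keys` follow the current `LeRobotDataset` API and may differ across versions:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Load the converted dataset from its local output directory.
ds = LeRobotDataset(repo_id="lerobot/pusht_video", root="outputs/pusht_video")

print(ds.meta.video_keys)              # camera keys now backed by MP4 files under videos/
print(ds.num_episodes, ds.num_frames)  # counts should match the source image dataset
```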

### Push to Hub

-Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
+Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:

```bash
lerobot-edit-dataset \
@@ -96,7 +159,7 @@ lerobot-edit-dataset \
    --new_repo_id lerobot/pusht_after_deletion \
    --operation.type delete_episodes \
    --operation.episode_indices "[0, 2, 5]" \
-   --push_to_hub
+   --push_to_hub true
```

There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.

+40 −82
@@ -24,7 +24,7 @@ Built from pure Transformer encoders, X-VLA scales naturally with model size and
  <img
    src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
    alt="XVLA Architecture 2"
-   style="width: 32%; max-width: 450px; height: auto;"
+   style="width: 60%; height: auto;"
  />
</p>

@@ -120,7 +120,7 @@ Adapted for Google Robot platforms.

### Recommended Training Configuration

-When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy:
+When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM and setting `policy.dtype=bfloat16` to avoid out-of-memory (OOM) errors.

```bash
lerobot-train \
@@ -129,25 +129,26 @@ lerobot-train \
    --job_name=xvla_training \
    --policy.path="lerobot/xvla-base" \
    --policy.repo_id="HF_USER/xvla-your-robot" \
+   --policy.dtype=bfloat16 \
    --steps=3000 \
    --policy.device=cuda \
-   --policy.freeze_vision_encoder=True \
-   --policy.freeze_language_encoder=True \
-   --policy.train_policy_transformer=True \
-   --policy.train_soft_prompts=True \
+   --policy.freeze_vision_encoder=false \
+   --policy.freeze_language_encoder=false \
+   --policy.train_policy_transformer=true \
+   --policy.train_soft_prompts=true \
    --policy.action_mode=YOUR_ACTION_MODE
```

### Training Parameters Explained

-| Parameter                  | Default | Description                               |
-| -------------------------- | ------- | ----------------------------------------- |
-| `freeze_vision_encoder`    | `True`  | Freeze the VLM vision encoder weights     |
-| `freeze_language_encoder`  | `True`  | Freeze the VLM language encoder weights   |
-| `train_policy_transformer` | `True`  | Allow policy transformer layers to train  |
-| `train_soft_prompts`       | `True`  | Allow soft prompts to train               |
+| Parameter                  | Default | Description                                    |
+| -------------------------- | ------- | ---------------------------------------------- |
+| `freeze_vision_encoder`    | `false` | Do not freeze the VLM vision encoder weights   |
+| `freeze_language_encoder`  | `false` | Do not freeze the VLM language encoder weights |
+| `train_policy_transformer` | `true`  | Allow policy transformer layers to train       |
+| `train_soft_prompts`       | `true`  | Allow soft prompts to train                    |

-**💡 Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute.
+**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders, and train the policy transformer and soft prompts as well.

### Example: Training on Bimanual Robot

@@ -157,14 +158,15 @@ lerobot-train \
    --output_dir=./outputs/xvla_bimanual \
    --job_name=xvla_so101_training \
    --policy.path="lerobot/xvla-base" \
+   --policy.dtype=bfloat16 \
    --policy.repo_id="YOUR_USERNAME/xvla-biso101" \
    --steps=3000 \
    --policy.device=cuda \
    --policy.action_mode=so101_bimanual \
-   --policy.freeze_vision_encoder=True \
-   --policy.freeze_language_encoder=True \
-   --policy.train_policy_transformer=True \
-   --policy.train_soft_prompts=True
+   --policy.freeze_vision_encoder=false \
+   --policy.freeze_language_encoder=false \
+   --policy.train_policy_transformer=true \
+   --policy.train_soft_prompts=true
```

💡 **Best Performance:** If you have sufficient computational resources and want to achieve the best X-VLA finetuning performance, you should follow the official finetuning strategy:

@@ -172,71 +174,7 @@ lerobot-train \
**🔥 Full-finetune all components with a custom learning-rate scheme**

To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance.
To enable this behavior, you must:

1. Implement a custom optimizer and register it in your training config

```python
from dataclasses import dataclass, asdict
from lerobot.optim.optimizers import OptimizerConfig
import torch


@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamW(OptimizerConfig):
    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        """
        Expect `named_parameters()` as input.
        Apply lr = lr / 10 for all VLM-related parameters.
        """
        assert isinstance(params, dict), \
            "Custom LR optimizer requires `named_parameters()` as inputs."
        kwargs = asdict(self)
        kwargs.pop("grad_clip_norm")
        vlm_group, other_group = [], []
        for name, p in params.items():
            if not p.requires_grad:
                continue
            if "vlm" in name.lower():
                vlm_group.append(p)
            else:
                other_group.append(p)

        param_groups = [
            {"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1},
            {"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay},
        ]

        return torch.optim.AdamW(param_groups, **kwargs)
```
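
Once registered, the custom optimizer still has to be selected for the run. Assuming the registered name can be picked through the training CLI the same way built-in optimizer configs are (the exact flag below is an assumption, not something this guide documents):

```bash
# Assumed flag: select the OptimizerConfig subclass registered as "xvla-adamw" above.
lerobot-train \
    --policy.path="lerobot/xvla-base" \
    --optimizer.type=xvla-adamw \
    --optimizer.lr=1e-4
```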

2. Modify X-VLA's `get_optim_params` to return named parameters

Replace:

```python
def get_optim_params(self) -> dict:
    """Return only trainable parameters for optimization."""
    return filter(lambda p: p.requires_grad, self.parameters())
```

with:

```python
def get_optim_params(self) -> dict:
    """Return trainable named parameters as a dict (name -> parameter)."""
    return {name: p for name, p in self.named_parameters() if p.requires_grad}
```

This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule.

This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.

❕Note

Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
@@ -326,6 +264,26 @@ domain_id = 3

The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.

The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:

#### Fine-tuning Datasets

| Dataset Name     | Domain ID |
| ---------------- | --------- |
| Bridge           | 0         |
| RT1              | 1         |
| Calvin           | 2         |
| libero           | 3         |
| widowx-air       | 4         |
| AIR-AGILEX-HQ    | 5         |
| robotwin2_abs_ee | 6         |
| robotwin2_clean  | 6         |
| robocasa-human   | 7         |
| VLABench         | 8         |
| AGIBOT-challenge | 9         |
| AIR-AGILEX       | 10        |
| AIRBOT           | 18        |

### 3. Processor Steps

X-VLA requires specific preprocessing and postprocessing steps for proper operation.
@@ -1,243 +0,0 @@
|
||||
# Synthetic Data Generation Script - Summary
|
||||
|
||||
## ✅ What Was Created
|
||||
|
||||
### Main Script: `annotate_pgen.py` (717 lines)
|
||||
A production-ready script implementing the Hi-Robot synthetic data generation pipeline.
|
||||
|
||||
**Key Features:**
|
||||
- ✅ Loads LeRobot datasets with skill annotations
|
||||
- ✅ Generates synthetic user prompts and robot utterances using Qwen VLM
|
||||
- ✅ **Temporal sampling** - generates dialogue every N seconds (default: 1s)
|
||||
- ✅ Adds `task_index_high_level` feature to dataset parquets
|
||||
- ✅ Saves high-level tasks to `meta/tasks_high_level.parquet`
|
||||
- ✅ Exports debug JSONL for quality analysis
|
||||
- ✅ Supports both Qwen2-VL and Qwen3-VL models
|
||||
- ✅ Multi-view camera support
|
||||
- ✅ Episode-aware processing with automatic first-frame sampling
|
||||
- ✅ Modular architecture for easy extension
|
||||
|
||||
### Supporting Files Created
|
||||
|
||||
1. **`run_pgen.sh`** - Convenience script with sensible defaults
|
||||
2. **`README_PGEN.md`** - Comprehensive documentation with examples
|
||||
3. **`example_pgen_usage.md`** - Practical examples and performance estimates
|
||||
4. **`SAMPLING_DIAGRAM.md`** - Visual explanation of temporal sampling strategy
|
||||
5. **`PGEN_SUMMARY.md`** - This file
|
||||
|
||||
## 🚀 Key Innovation: Temporal Sampling
|
||||
|
||||
The script processes **ALL episodes** in the dataset efficiently via `--sample-interval`:
|
||||
|
||||
```bash
|
||||
# Instead of calling VLM for every frame (expensive):
|
||||
# 15,000 frames × VLM call = ~5 hours
|
||||
|
||||
# Generate dialogue every 1 second (efficient):
|
||||
python annotate_pgen.py --repo-id dataset --model qwen --sample-interval 1.0
|
||||
# 15,000 frames processed, only ~500 VLM calls (30x speedup!)
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- Process ALL frames in ALL episodes (complete coverage)
|
||||
- Generate dialogue at sampled timepoints (e.g., every 1 second)
|
||||
- Propagate task indices to intermediate frames
|
||||
- Always sample first frame of each episode
|
||||
- All frames get labeled, but VLM is only called for samples
|
||||
- No dummy values or skipped episodes
|
||||
|
||||
**Benefits:**
|
||||
- 30-100x speedup depending on interval
|
||||
- Maintains temporal coherence
|
||||
- Reduces cost without losing quality
|
||||
- Configurable based on skill duration (see the sketch below)
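
A minimal sketch of this sampling rule (illustrative only; the real `annotate_pgen.py` carries more per-frame context and calls the VLM where the comment indicates):

```python
def assign_task_indices(frames, sample_interval=1.0):
    """frames: iterable of dicts with 'episode_id' and 'timestamp' (seconds), in order."""
    last_sample_ts = {}    # episode_id -> timestamp of the last VLM call
    current_task_idx = {}  # episode_id -> task index propagated to intermediate frames
    next_task_idx = 0
    labeled = []
    for frame in frames:
        ep, ts = frame["episode_id"], frame["timestamp"]
        first_frame_of_episode = ep not in last_sample_ts
        if first_frame_of_episode or ts - last_sample_ts[ep] >= sample_interval:
            # A sampled frame: this is where the VLM generates a new dialogue.
            current_task_idx[ep] = next_task_idx
            next_task_idx += 1
            last_sample_ts[ep] = ts
        labeled.append({**frame, "task_index_high_level": current_task_idx[ep]})
    return labeled
```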
|
||||
|
||||
## 📊 Efficiency Comparison
|
||||
|
||||
For a typical 15,000 frame dataset at 30 fps:
|
||||
|
||||
| Method | VLM Calls | Time | Cost |
|
||||
|--------|-----------|------|------|
|
||||
| Every frame | 15,000 | ~5 hours | $$$$ |
|
||||
| Every 0.5s | 1,000 | ~20 min | $$$ |
|
||||
| **Every 1s** (default) | **500** | **~10 min** | **$$** |
|
||||
| Every 2s | 250 | ~5 min | $ |
|
||||
|
||||
## 🎯 Usage
|
||||
|
||||
### Quick Test (5s sampling for fast iteration)
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 5.0 \
|
||||
--output-dir ./outputs/test_quick
|
||||
```
|
||||
|
||||
### Production Run (Recommended Settings)
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 1.0 \
|
||||
--output-dir ./outputs/full_pgen
|
||||
```
|
||||
|
||||
### High-Quality with Qwen3
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
|
||||
--sample-interval 0.5 \
|
||||
--temperature 0.6 \
|
||||
--output-dir ./outputs/high_quality
|
||||
```
|
||||
|
||||
## 📦 Output Structure
|
||||
|
||||
After running, you'll have:
|
||||
|
||||
```
|
||||
dataset_root/
|
||||
├── meta/
|
||||
│ ├── tasks_high_level.parquet # High-level tasks with prompts/utterances
|
||||
│ └── syn_annotations.jsonl # Debug: full context for each sample
|
||||
└── data/
|
||||
└── chunk-000/
|
||||
└── file-000.parquet # Updated with task_index_high_level
|
||||
```
|
||||
|
||||
**New feature added to all parquet files:**
|
||||
- `task_index_high_level` (int64): Links to tasks_high_level.parquet
|
||||
|
||||
## 🔧 All Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `--repo-id` / `--data-dir` | - | Dataset source |
|
||||
| `--model` | Qwen/Qwen2-VL-7B-Instruct | VLM model |
|
||||
| `--device` | cuda | Device to use |
|
||||
| `--dtype` | bfloat16 | Model precision |
|
||||
| `--temperature` | 0.7 | Sampling temperature |
|
||||
| **`--sample-interval`** | **1.0** | **Generate every N seconds (all episodes processed)** |
|
||||
| `--num-image-views-per-sample` | 1 | Number of cameras |
|
||||
| `--batch-size` | 1 | Batch size (currently unused) |
|
||||
| `--output-dir` | None | Output directory |
|
||||
| `--push-to-hub` | False | Push to HuggingFace |
|
||||
|
||||
## 🎨 Generated Data Format
|
||||
|
||||
Each sampled frame produces:
|
||||
|
||||
```json
|
||||
{
|
||||
"scenario_type": "specific_object",
|
||||
"response_type": "confirmation",
|
||||
"user_prompt": "Can you pick up the pink brick?",
|
||||
"robot_utterance": "Sure, I'll grab the pink lego brick.",
|
||||
"skill": "robot arm picks up pink lego brick",
|
||||
"episode_id": 0,
|
||||
"frame_index": 45,
|
||||
"timestamp": 1.5,
|
||||
"skill_history": ["robot arm moves towards pink lego brick"],
|
||||
"task_description": "pink lego brick into the transparent box"
|
||||
}
|
||||
```
|
||||
|
||||
**Scenario Types:**
|
||||
- specific_object, negative_task, situated_correction, implicit_request, constraint_based
|
||||
|
||||
**Response Types:**
|
||||
- confirmation, clarification, acknowledgment, constraint_acknowledgment
|
||||
|
||||
## 🔬 Code Architecture
|
||||
|
||||
```python
|
||||
# Main components (modular design)
|
||||
|
||||
class QwenPgen:
|
||||
"""VLM wrapper supporting Qwen2/3"""
|
||||
def call_qwen(images, prompt) -> dict
|
||||
|
||||
def construct_prompt(task, history, skill) -> str:
|
||||
"""Build contextual prompt with history"""
|
||||
|
||||
def annotate_sample(pgen, images, ...) -> dict:
|
||||
"""Generate dialogue for one sample"""
|
||||
|
||||
def generate_synthetic_data(dataset, pgen, ...) -> tuple:
|
||||
"""Process entire dataset with temporal sampling"""
|
||||
# Core sampling logic:
|
||||
# - Track last_sample_timestamp per episode
|
||||
# - Sample if time_elapsed >= sample_interval
|
||||
# - Always sample first frame of episodes
|
||||
# - Propagate task_index to intermediate frames
|
||||
|
||||
def main():
|
||||
"""CLI entrypoint with argparse"""
|
||||
```
|
||||
|
||||
## ✨ Next Steps
|
||||
|
||||
1. **Quick test with large interval:**
|
||||
```bash
|
||||
# Fast iteration - samples every 5 seconds
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /path/to/dataset \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 5.0 \
|
||||
--output-dir ./outputs/quick_test
|
||||
```
|
||||
|
||||
2. **Verify output quality:**
|
||||
```bash
|
||||
head outputs/quick_test/meta/syn_annotations.jsonl
|
||||
```
|
||||
|
||||
3. **Production run:**
|
||||
```bash
|
||||
# Standard 1 second sampling for production
|
||||
bash examples/dataset/run_pgen.sh
|
||||
```
|
||||
|
||||
4. **Use in training:**
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
ds = LeRobotDataset(repo_id="...", root="outputs/pgen_annotations")
|
||||
|
||||
# Access high-level task for each frame
|
||||
frame = ds[100]
|
||||
task_idx = frame["task_index_high_level"].item()
|
||||
```
|
||||
|
||||
## 📚 Documentation Files
|
||||
|
||||
- **`README_PGEN.md`**: Full API reference and troubleshooting
|
||||
- **`example_pgen_usage.md`**: Practical examples with performance estimates
|
||||
- **`SAMPLING_DIAGRAM.md`**: Visual explanation of temporal sampling
|
||||
- **`PGEN_SUMMARY.md`**: This overview document
|
||||
|
||||
## 🎯 Success Criteria
|
||||
|
||||
✅ Script generates synthetic dialogue using Qwen VLM
|
||||
✅ Adds `task_index_high_level` feature to dataset
|
||||
✅ Saves tasks to `tasks_high_level.parquet`
|
||||
✅ Implements efficient temporal sampling (30-100x speedup)
|
||||
✅ Handles episode boundaries correctly
|
||||
✅ Produces diverse interaction types (scenarios + responses)
|
||||
✅ Maintains temporal coherence within episodes
|
||||
✅ Includes comprehensive documentation and examples
|
||||
✅ Ready for production use on real datasets
|
||||
|
||||
## 💡 Key Takeaway
|
||||
|
||||
**The script processes ALL episodes with intelligent sampling:**
|
||||
- `--sample-interval` controls how often VLM is called (default: 1.0s)
|
||||
- ALL frames in ALL episodes get labeled (complete coverage)
|
||||
- Intermediate frames inherit from most recent sample (temporal coherence)
|
||||
- Achieves 30-100x speedup while maintaining quality
|
||||
- Adjust interval based on use case: 5.0s for testing, 1.0s for production, 0.5s for fine detail
|
||||
|
||||
This makes the synthetic data generation **practical, scalable, and complete** for real-world datasets!
|
||||
|
||||
@@ -1,243 +0,0 @@
|
||||
# Synthetic Data Generation for Hierarchical Robot Policies
|
||||
|
||||
This directory contains `annotate_pgen.py`, a script for generating synthetic user prompts and robot utterances for hierarchical policy training using Vision-Language Models (VLMs).
|
||||
|
||||
## Overview
|
||||
|
||||
The script implements the synthetic data generation pipeline described in the Hi-Robot paper:
|
||||
|
||||
1. **Load** a LeRobot dataset with skill annotations (from `annotate.py`)
|
||||
2. **Generate** synthetic dialogue using Qwen VLM:
|
||||
- User prompts (ℓ_t): Natural requests that lead to specific skills
|
||||
- Robot utterances (u_t): Acknowledgments and clarifications
|
||||
3. **Save** results as a new dataset feature `task_index_high_level`
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. First, annotate your dataset with skills using `annotate.py`:
|
||||
|
||||
```bash
|
||||
python examples/dataset/annotate.py \
|
||||
--repo-id lerobot/svla_so101_pickplace \
|
||||
--video-key observation.images.base \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct
|
||||
```
|
||||
|
||||
This creates `meta/skills.json` with skill segmentation for each episode.
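
For reference, `skills.json` maps skill names to task indices and stores per-episode skill segments with start/end times. An abridged example (structure as described by this repo's annotation prompt; field values are illustrative):

```json
{
  "coarse_description": "pink lego brick into the transparent box",
  "skill_to_task_index": {
    "robot arm moves towards pink lego brick": 12,
    "robot arm picks up pink lego brick": 19
  },
  "episodes": {
    "0": {
      "episode_index": 0,
      "description": "pink lego brick into the transparent box",
      "skills": [
        {"name": "robot arm moves towards pink lego brick", "start": 0.0, "end": 1.8},
        {"name": "robot arm picks up pink lego brick", "start": 1.8, "end": 3.1}
      ]
    }
  }
}
```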
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--repo-id lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 1.0 \
|
||||
--output-dir ./outputs/pgen_dataset
|
||||
```
|
||||
|
||||
**Note**: The script processes **all episodes** in the dataset. It generates dialogue every 1 second (`--sample-interval 1.0`) using temporal sampling. Frames between samples reuse the last generated dialogue. This makes the process efficient while ensuring complete dataset coverage.
|
||||
|
||||
### Advanced Options
|
||||
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--repo-id lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
|
||||
--temperature 0.8 \
|
||||
--sample-interval 0.5 \
|
||||
--num-image-views-per-sample 2 \
|
||||
--output-dir ./outputs/pgen_dataset \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
This example uses a more powerful model and samples every 0.5 seconds for finer granularity.
|
||||
|
||||
### Fast Testing (larger interval)
|
||||
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--repo-id lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 5.0 \
|
||||
--output-dir ./outputs/pgen_quick_test
|
||||
```
|
||||
|
||||
Use a larger interval (5.0 seconds) for rapid iteration during development. All episodes are still processed.
|
||||
|
||||
### Using Local Dataset
|
||||
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--output-dir ./outputs/pgen_dataset
|
||||
```
|
||||
|
||||
## Output Files
|
||||
|
||||
The script produces several outputs:
|
||||
|
||||
1. **`meta/tasks_high_level.parquet`**: High-level tasks with user prompts and robot utterances
|
||||
- Columns: task_index, user_prompt, robot_utterance, skill, scenario_type, response_type
|
||||
|
||||
2. **`meta/syn_annotations.jsonl`**: Debug file with all generated dialogues
|
||||
- One JSON object per line with full context for each frame
|
||||
|
||||
3. **Modified dataset**: New dataset with `task_index_high_level` feature added to all parquet files
|
||||
|
||||
## Scenario and Response Types
|
||||
|
||||
The generator produces diverse interaction types:
|
||||
|
||||
### Scenario Types
|
||||
- **specific_object**: Direct specification of objects/actions
|
||||
- **negative_task**: Instructions about what NOT to do
|
||||
- **situated_correction**: Adjustments based on current state
|
||||
- **implicit_request**: Implied needs without direct commands
|
||||
- **constraint_based**: Specific constraints or preferences
|
||||
|
||||
### Response Types
|
||||
- **confirmation**: Simple acknowledgment ("OK, I'll do X")
|
||||
- **clarification**: Seeking confirmation ("Just to confirm...")
|
||||
- **acknowledgment**: Action acknowledgment ("Got it, doing X")
|
||||
- **constraint_acknowledgment**: Acknowledging constraints ("Sure, I'll X while Y")
|
||||
|
||||
## Example Generated Data
|
||||
|
||||
```json
|
||||
{
|
||||
"episode_id": 0,
|
||||
"frame_index": 45,
|
||||
"timestamp": 2.5,
|
||||
"skill_current": "robot arm picks up pink lego brick",
|
||||
"skill_history": ["robot arm moves towards pink lego brick"],
|
||||
"task_description": "pink lego brick into the transparent box",
|
||||
"scenario_type": "specific_object",
|
||||
"response_type": "confirmation",
|
||||
"user_prompt": "Can you grab the pink brick?",
|
||||
"robot_utterance": "Sure, I'll pick up the pink lego brick."
|
||||
}
|
||||
```
|
||||
|
||||
## Accessing the Data
|
||||
|
||||
After running the script, access the synthetic data in your code:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
import pandas as pd
|
||||
|
||||
# Load modified dataset
|
||||
dataset = LeRobotDataset(repo_id="lerobot/svla_so101_pickplace_with_high_level_tasks")
|
||||
|
||||
# Access frame with high-level task
|
||||
frame = dataset[100]
|
||||
high_level_task_idx = frame["task_index_high_level"].item()
|
||||
|
||||
# Load high-level tasks
|
||||
tasks_df = pd.read_parquet(dataset.root / "meta" / "tasks_high_level.parquet")
|
||||
task_info = tasks_df.iloc[high_level_task_idx]
|
||||
|
||||
print(f"User prompt: {task_info['user_prompt']}")
|
||||
print(f"Robot utterance: {task_info['robot_utterance']}")
|
||||
print(f"Skill: {task_info['skill']}")
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
The script is modular and extensible:
|
||||
|
||||
```python
|
||||
# Core components
|
||||
class QwenPgen:
|
||||
"""VLM wrapper for generation"""
|
||||
def call_qwen(images, prompt) -> dict
|
||||
|
||||
def construct_prompt(task, history, skill) -> str
|
||||
"""Build prompt for VLM"""
|
||||
|
||||
def annotate_sample(pgen, images, ...) -> dict
|
||||
"""Generate dialogue for one sample"""
|
||||
|
||||
def generate_synthetic_data(dataset, pgen, ...) -> tuple
|
||||
"""Process entire dataset"""
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `--repo-id` | - | HuggingFace dataset ID |
|
||||
| `--data-dir` | - | Local dataset path |
|
||||
| `--model` | Qwen/Qwen2-VL-7B-Instruct | VLM model name |
|
||||
| `--device` | cuda | Device (cuda/cpu) |
|
||||
| `--dtype` | bfloat16 | Model precision |
|
||||
| `--temperature` | 0.7 | Sampling temperature |
|
||||
| `--sample-interval` | 1.0 | Generate dialogue every N seconds (all episodes processed) |
|
||||
| `--num-image-views-per-sample` | 1 | Number of cameras |
|
||||
| `--output-dir` | None | Output directory |
|
||||
| `--push-to-hub` | False | Push to HuggingFace Hub |
|
||||
|
||||
## Sampling Strategy
|
||||
|
||||
The script uses **temporal sampling** to efficiently generate dialogue:
|
||||
|
||||
- **Default**: Generate dialogue every 1 second (`--sample-interval 1.0`)
|
||||
- **Efficiency**: If a dataset runs at 30fps, this samples ~3% of frames
|
||||
- **Propagation**: Frames between samples reuse the last generated task_index
|
||||
- **Episode-aware**: Always samples the first frame of each episode
|
||||
|
||||
### Example with 30 fps dataset:
|
||||
```bash
|
||||
# Sample every 1 second (every 30 frames)
|
||||
--sample-interval 1.0 # ~3,000 generations for a 100 episode dataset (30 sec/episode)
|
||||
|
||||
# Sample every 0.5 seconds (every 15 frames)
|
||||
--sample-interval 0.5 # ~6,000 generations (more granular)
|
||||
|
||||
# Sample every 2 seconds (every 60 frames)
|
||||
--sample-interval 2.0 # ~1,500 generations (more efficient)
|
||||
```
|
||||
|
||||
### Why sampling works:
|
||||
- Skills typically last 1-3 seconds
|
||||
- Dialogue doesn't need to change every frame
|
||||
- Reduces computational cost by 30-100x
|
||||
- Still provides good coverage for training
|
||||
|
||||
## Tips
|
||||
|
||||
1. **Quick testing**: Use larger `--sample-interval` (e.g., 5.0 or 10.0) for rapid iteration
|
||||
2. **Monitor GPU**: VLM inference is memory-intensive
|
||||
3. **Check outputs**: Review `syn_annotations.jsonl` for quality
|
||||
4. **Adjust temperature**: Higher = more diverse, lower = more consistent
|
||||
5. **Multiple views**: Use `--num-image-views-per-sample 2+` for better context
|
||||
6. **Tune sampling**: Start with 1.0s, increase for speed (testing), decrease for granularity (production)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### No skills.json found
|
||||
Run `annotate.py` first to generate skill annotations.
|
||||
|
||||
### Out of memory
|
||||
- Reduce batch size to 1
|
||||
- Use smaller model (Qwen2-VL-7B instead of Qwen3-VL-30B)
|
||||
- Process fewer samples at a time
|
||||
|
||||
### Poor quality generations
|
||||
- Adjust temperature (try 0.6-0.9)
|
||||
- Check that skills.json has good annotations
|
||||
- Ensure images are loading correctly
|
||||
|
||||
## Citation
|
||||
|
||||
Based on the Hi-Robot paper's synthetic data generation approach:
|
||||
```
|
||||
@article{hirobot2024,
|
||||
title={Hi-Robot: Hierarchical Robot Learning with Vision-Language Models},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
# Temporal Sampling Strategy Visualization
|
||||
|
||||
## How `--sample-interval` Works
|
||||
|
||||
### Example: 30 fps dataset, `--sample-interval 1.0` (1 second)
|
||||
|
||||
```
|
||||
Timeline (seconds): 0.0 0.5 1.0 1.5 2.0 2.5 3.0
|
||||
│ │ │ │ │ │ │
|
||||
Frames: 0───15───30───45───60───75───90───105──120──135──150
|
||||
│ │ │ │ │ │ │
|
||||
▼ ▼ ▼ ▼
|
||||
Sampled: YES NO YES NO YES NO YES
|
||||
│ │ │ │
|
||||
Task Index: [0]──────────────>[1]──────────────>[2]──────────────>[3]
|
||||
│ │ │ │
|
||||
VLM Called: ✓ Gen ✓ Gen ✓ Gen ✓ Gen
|
||||
dialogue dialogue dialogue dialogue
|
||||
│ │ │ │
|
||||
Frames 0-29 ─────┘ │ │ │
|
||||
get task 0 │ │ │
|
||||
│ │ │
|
||||
Frames 30-59 ────────────────────────┘ │ │
|
||||
get task 1 │ │
|
||||
│ │
|
||||
Frames 60-89 ──────────────────────────────────────────┘ │
|
||||
get task 2 │
|
||||
│
|
||||
Frames 90-119 ────────────────────────────────────────────────────────────┘
|
||||
get task 3
|
||||
```
|
||||
|
||||
## Comparison: Different Sampling Intervals
|
||||
|
||||
### `--sample-interval 2.0` (every 2 seconds)
|
||||
```
|
||||
Timeline: 0.0 1.0 2.0 3.0 4.0 5.0 6.0
|
||||
│ │ │ │ │ │ │
|
||||
Sampled: YES NO YES NO YES NO YES
|
||||
│ │ │ │
|
||||
Tasks: [0]───────────────>[1]───────────────>[2]───────────────>[3]
|
||||
|
||||
VLM Calls: 4 (fewer calls, faster but less granular)
|
||||
```
|
||||
|
||||
### `--sample-interval 1.0` (every 1 second) - **DEFAULT**
|
||||
```
|
||||
Timeline: 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0
|
||||
│ │ │ │ │ │ │ │ │ │ │ │ │
|
||||
Sampled: YES NO YES NO YES NO YES NO YES NO YES NO YES
|
||||
│ │ │ │ │ │ │
|
||||
Tasks: [0]─────────>[1]─────────>[2]─────────>[3]─────────>[4]─────────>[5]─────>[6]
|
||||
|
||||
VLM Calls: 7 (balanced coverage and speed)
|
||||
```
|
||||
|
||||
### `--sample-interval 0.5` (every 0.5 seconds)
|
||||
```
|
||||
Timeline: 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0
|
||||
│ │ │ │ │ │ │ │ │ │ │ │ │
|
||||
Sampled: YES YES YES YES YES YES YES YES YES YES YES YES YES
|
||||
│ │ │ │ │ │ │ │ │ │ │ │ │
|
||||
Tasks: [0]─>[1]─>[2]─>[3]─>[4]─>[5]─>[6]─>[7]─>[8]─>[9]─>[10]>[11]>[12]
|
||||
|
||||
VLM Calls: 13 (high granularity, slower but more detailed)
|
||||
```
|
||||
|
||||
## Episode Boundaries
|
||||
|
||||
The script always samples the **first frame** of each episode:
|
||||
|
||||
```
|
||||
Episode 0 Episode 1 Episode 2
|
||||
├─────────────────────────────────┤├─────────────────────────────────┤├──────...
|
||||
│ ││ ││
|
||||
Frame: 0 30 60 90 120 130 160 190 220 250 260 290 320
|
||||
Time: 0.0 1.0 2.0 3.0 4.0 0.0 1.0 2.0 3.0 4.0 0.0 1.0 2.0
|
||||
│ │ │ │ │ │ │ │ │ │ │ │ │
|
||||
▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
|
||||
Sample:YES YES YES YES YES YES YES YES YES YES YES YES YES
|
||||
│ │ │ │ │ │ │ │ │ │ │ │ │
|
||||
Task: 0────1─────2─────3────4 5─────6─────7─────8────9 10────11───12
|
||||
|
||||
Note: Frames 0, 130, 260 are ALWAYS sampled (episode starts)
|
||||
Even if they're within the sample-interval window
|
||||
```
|
||||
|
||||
## Real-World Example: svla_so101_pickplace Dataset
|
||||
|
||||
Typical stats:
|
||||
- **Total episodes**: 50
|
||||
- **Avg episode length**: 300 frames (10 seconds at 30 fps)
|
||||
- **Total frames**: 15,000
|
||||
|
||||
### Without Sampling (every frame)
|
||||
```
|
||||
Frames processed: 15,000
|
||||
VLM calls: 15,000
|
||||
Time estimate: ~5 hours
|
||||
Unique tasks: ~12,000 (lots of duplicates)
|
||||
```
|
||||
|
||||
### With `--sample-interval 1.0` (every 1 second)
|
||||
```
|
||||
Frames processed: 15,000 ✓
|
||||
VLM calls: 500
|
||||
Time estimate: ~10 minutes
|
||||
Unique tasks: ~450 (meaningful variety)
|
||||
Efficiency gain: 30x faster
|
||||
```
|
||||
|
||||
### With `--sample-interval 2.0` (every 2 seconds)
|
||||
```
|
||||
Frames processed: 15,000 ✓
|
||||
VLM calls: 250
|
||||
Time estimate: ~5 minutes
|
||||
Unique tasks: ~220
|
||||
Efficiency gain: 60x faster
|
||||
```
|
||||
|
||||
## Key Points
|
||||
|
||||
1. **All frames get labeled**: Every frame gets a `task_index_high_level`
|
||||
2. **Only sampled frames call VLM**: Huge efficiency gain
|
||||
3. **Temporal coherence**: Nearby frames share the same task
|
||||
4. **Episode-aware**: Always samples episode starts
|
||||
5. **Configurable**: Adjust `--sample-interval` based on your needs
|
||||
|
||||
## Choosing Your Sampling Interval
|
||||
|
||||
| Use Case | Recommended Interval | Why |
|
||||
|----------|---------------------|-----|
|
||||
| Quick testing | 2.0s | Fastest iteration |
|
||||
| Standard training | 1.0s | Good balance |
|
||||
| High-quality dataset | 0.5s | Better coverage |
|
||||
| Fine-grained control | 0.33s | Very detailed |
|
||||
| Dense annotations | 0.1s | Nearly every frame |
|
||||
|
||||
**Rule of thumb**: Match your sampling interval to your typical skill duration.
|
||||
If skills last 1-3 seconds, sampling every 1 second captures each skill multiple times.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,143 +0,0 @@
|
||||
# Example: Synthetic Data Generation with Sampling
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Test with 100 frames and 1 second sampling
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--num-samples 100 \
|
||||
--sample-interval 1.0 \
|
||||
--output-dir ./outputs/test_pgen
|
||||
```
|
||||
|
||||
**Expected behavior** (assuming 30 fps):
|
||||
- Total frames: 100
|
||||
- Frames sampled: ~4 (every 30 frames = 1 second)
|
||||
- Efficiency: 96% fewer VLM calls
|
||||
- Output: All 100 frames get `task_index_high_level`, but only 4 unique dialogues generated
|
||||
|
||||
### 2. Process full dataset with different sampling rates
|
||||
|
||||
#### Conservative (every 2 seconds)
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 2.0 \
|
||||
--output-dir ./outputs/pgen_2s
|
||||
```
|
||||
|
||||
#### Standard (every 1 second) - **RECOMMENDED**
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 1.0 \
|
||||
--output-dir ./outputs/pgen_1s
|
||||
```
|
||||
|
||||
#### Fine-grained (every 0.5 seconds)
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--sample-interval 0.5 \
|
||||
--output-dir ./outputs/pgen_0.5s
|
||||
```
|
||||
|
||||
## Performance Estimates
|
||||
|
||||
For a dataset with:
|
||||
- 100 episodes
|
||||
- 10 seconds per episode (average)
|
||||
- 30 fps
|
||||
- Total frames: 30,000
|
||||
|
||||
| Sampling Interval | Frames Sampled | % Sampled | Speedup | Time Estimate |
|
||||
|-------------------|----------------|-----------|---------|---------------|
|
||||
| Every frame (0.033s) | 30,000 | 100% | 1x | ~10 hours |
|
||||
| 0.5 seconds | 2,000 | 6.7% | 15x | ~40 min |
|
||||
| **1.0 seconds** | **1,000** | **3.3%** | **30x** | **~20 min** |
|
||||
| 2.0 seconds | 500 | 1.7% | 60x | ~10 min |
|
||||
|
||||
*Note: Times are approximate and depend on GPU, model size, and generation speed*
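
The call counts in the table follow directly from the frame count, the fps, and the sampling interval; a quick back-of-the-envelope check:

```python
total_frames = 30_000  # 100 episodes x 10 s x 30 fps
fps = 30

for interval_s in (1 / fps, 0.5, 1.0, 2.0):
    frames_between_calls = max(1, round(fps * interval_s))
    vlm_calls = total_frames // frames_between_calls
    print(f"interval={interval_s:>5.3f}s -> ~{vlm_calls} VLM calls "
          f"({100 * vlm_calls / total_frames:.1f}% of frames)")
```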
|
||||
|
||||
## Understanding the Output
|
||||
|
||||
### Console Output Example
|
||||
```
|
||||
[cyan]Generating synthetic data for 30000 frames...[/cyan]
|
||||
[cyan]Sampling interval: 1.0s (fps: 30)[/cyan]
|
||||
Generating synthetic dialogue: 100%|████████| 30000/30000 [20:15<00:00, 24.68it/s]
|
||||
[green]✓ Sampled 1000 frames out of 30000 (3.3%)[/green]
|
||||
[green]✓ Generated 450 unique high-level tasks[/green]
|
||||
```
|
||||
|
||||
### What happens:
|
||||
1. **Frame 0 (t=0.0s)**: Generate dialogue → Task index 0
|
||||
2. **Frames 1-29 (t=0.033s-0.967s)**: Reuse task index 0
|
||||
3. **Frame 30 (t=1.0s)**: Generate new dialogue → Task index 1
|
||||
4. **Frames 31-59 (t=1.033s-1.967s)**: Reuse task index 1
|
||||
5. And so on...
|
||||
|
||||
### Result:
|
||||
- Every frame has a `task_index_high_level`
|
||||
- Only sampled frames have unique dialogues generated
|
||||
- Intermediate frames inherit from the most recent sample
|
||||
- Maintains temporal coherence within episodes
|
||||
|
||||
## Checking Your Results
|
||||
|
||||
After running, verify the output:
|
||||
|
||||
```bash
|
||||
# Check the generated tasks
|
||||
python -c "
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
tasks = pd.read_parquet('outputs/test_pgen/meta/tasks_high_level.parquet')
|
||||
print(f'Total unique tasks: {len(tasks)}')
|
||||
print(f'Sample tasks:')
|
||||
print(tasks[['user_prompt', 'robot_utterance', 'skill']].head())
|
||||
"
|
||||
|
||||
# Check debug output
|
||||
head outputs/test_pgen/meta/syn_annotations.jsonl
|
||||
|
||||
# Load and verify dataset
|
||||
python -c "
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
ds = LeRobotDataset(repo_id='local_with_high_level_tasks',
|
||||
root='outputs/test_pgen')
|
||||
print(f'Dataset has {len(ds)} frames')
|
||||
print(f'Features: {list(ds.features.keys())}')
|
||||
assert 'task_index_high_level' in ds.features
|
||||
print('✓ task_index_high_level feature added successfully!')
|
||||
"
|
||||
```
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Development/Testing
|
||||
```bash
|
||||
--sample-interval 2.0 # Fast iteration
|
||||
--num-samples 500 # Small subset
|
||||
```
|
||||
|
||||
### Production Training
|
||||
```bash
|
||||
--sample-interval 1.0 # Good coverage
|
||||
# Process all samples (no --num-samples)
|
||||
```
|
||||
|
||||
### High-Quality Dataset
|
||||
```bash
|
||||
--sample-interval 0.5 # Fine-grained
|
||||
--temperature 0.6 # More consistent
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct # Larger model
|
||||
```
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

breakpoint()
# Scratch snippet: `inputs_embeds`, `attention_mask`, `position_ids`, and `adarms_cond`
# are assumed to be prepared earlier (pi0-style prefix embedding); `adarms_cond` is a
# lerobot-specific argument, not part of the stock transformers PaliGemma API.
prefix_output = model.language_model.forward(
    inputs_embeds=inputs_embeds[0],
    attention_mask=attention_mask,
    position_ids=position_ids,
    adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
)
prefix_past_key_values = prefix_output.past_key_values
# prefix_output to be used for the language head
# shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
prefix_output = prefix_output.last_hidden_state
|
||||
@@ -1,58 +0,0 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
# import make_pre_post_processors
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
)
|
||||
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
|
||||
# rename map --rename_map='{
|
||||
# "observation.images.side": "observation.images.base_0_rgb",
|
||||
# "observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
# }'
|
||||
rename_map = {
|
||||
"observation.images.side": "observation.images.base_0_rgb",
|
||||
"observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
}
|
||||
policy = make_policy(
|
||||
cfg=cfg,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=rename_map,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
batch = pre_processor(batch)
|
||||
|
||||
# Test training forward pass
|
||||
policy.train()
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"Training loss: {loss_dict}")
|
||||
|
||||
# Test inference
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
actions = policy.predict_action_chunk(batch)
|
||||
print(f"Predicted actions shape: {actions.shape}")
|
||||
@@ -1,23 +0,0 @@
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
print(batch.keys())
|
||||
print(batch['task_index_high_level'].shape)
|
||||
print(batch['task_index_high_level'])
|
||||
print(batch['user_prompt'][0])
|
||||
print(batch['robot_utterance'][0])
|
||||
print(batch['task'][0])
|
||||
breakpoint()
|
||||
@@ -1,334 +0,0 @@
|
||||
Generate annotate_pgen.py using Qwen for synthetic data generation
|
||||
|
||||
You are writing a Python script called annotate_pgen.py.
|
||||
This script generates synthetic user prompts (ℓ_t) and robot utterances (u_t) for Hi Robot–style hierarchical policy training, using Qwen3-VL as the generator model (pgen).
|
||||
|
||||
SCRIPT PURPOSE
|
||||
|
||||
The script must:
|
||||
|
||||
Load Dlabeled, a LeRobot Dataset that has been annotated using the annotate.py script, which contains:
|
||||
|
||||
images: list of image paths at time t
|
||||
|
||||
skill_current: the annotated skill label (ℓ̂_t)
|
||||
|
||||
skill_history: list of previous skill labels (ℓ̂₀ … ℓ̂_{t−1}); those were annotated earlier, and you can find details on them stored in the dataset inside DATA_PATH/meta/skills.json
|
||||
|
||||
you will find something like
|
||||
|
||||
{
|
||||
"coarse_description": "pink lego brick into the transparent box",
|
||||
"skill_to_task_index": {
|
||||
"robot arm picks up pink lego brick": 19,
|
||||
"robot arm approaches transparent box": 3,
|
||||
"robot arm retracts from transparent box": 28,
|
||||
"robot arm moves towards pink lego brick": 12,
|
||||
"robot arm releases red lego brick into box": 26,
|
||||
"robot arm releases red lego brick into transparent box": 27,
|
||||
"robot arm closes gripper to pick up the pink lego brick": 5,
|
||||
"robot arm lifts the pink lego brick": 7,
|
||||
etc..
|
||||
},
|
||||
"episodes": {
|
||||
"0": {
|
||||
"episode_index": 0,
|
||||
"description": "pink lego brick into the transparent box",
|
||||
"skills": [
|
||||
{
|
||||
"name": "robot arm moves towards pink lego brick",
|
||||
"start": 0.0,
|
||||
"end": 1.8
|
||||
},
|
||||
{
|
||||
"name": "robot arm picks up pink lego brick",
|
||||
"start": 1.8,
|
||||
"end": 3.1
|
||||
},
|
||||
{
|
||||
"name": "robot arm moves towards transparent box",
|
||||
"start": 3.1,
|
||||
"end": 5.5
|
||||
},
|
||||
{
|
||||
"name": "robot arm releases pink lego brick into transparent box",
|
||||
"start": 5.5,
|
||||
"end": 7.0
|
||||
},
|
||||
{
|
||||
"name": "robot arm retracts from transparent box",
|
||||
"start": 7.0,
|
||||
"end": 10.1
|
||||
}
|
||||
]
|
||||
},
|
||||
"1": {
|
||||
"episode_index": 1,
|
||||
"description": "pink lego brick into the transparent box",
|
||||
"skills": [
|
||||
{
|
||||
"name": "robot arm moves towards red lego brick",
|
||||
"start": 0.0,
|
||||
"end": 1.2
|
||||
},
|
||||
{
|
||||
"name": "robot arm picks up red lego brick",
|
||||
"start": 1.2,
|
||||
"end": 2.0
|
||||
},
|
||||
{
|
||||
"name": "robot arm moves towards transparent box",
|
||||
"start": 2.0,
|
||||
"end": 3.8
|
||||
},
|
||||
{
|
||||
"name": "robot arm places red lego brick into transparent box",
|
||||
"start": 3.8,
|
||||
"end": 5.0
|
||||
},
|
||||
{
|
||||
"name": "robot arm moves away from transparent box",
|
||||
"start": 5.0,
|
||||
"end": 8.9
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
Notice how `task_description` is a high-level description (e.g., "make a sandwich"), stored in the `description` field of each episode.
|
||||
|
||||
For each sample, call Qwen VLM to generate:
|
||||
|
||||
synthetic user prompt ℓ_t
|
||||
|
||||
synthetic robot response u_t
|
||||
|
||||
Save results to D_syn in Parquet format inside DATA_PATH/meta/tasks.parquet; note that tasks.parquet already contains the other tasks, so you need to update it
|
||||
|
||||
Should be modular, clean, easy to extend, with:
|
||||
|
||||
a PGEN_PROMPT_TEMPLATE
|
||||
|
||||
a construct_prompt() method
|
||||
|
||||
a call_qwen() method
|
||||
|
||||
a annotate_sample() method
|
||||
|
||||
a CLI entrypoint (if __name__ == "__main__":)
|
||||
|
||||
📦 INPUT FORMAT (Dlabeled)
|
||||
|
||||
The script should expect Dlabeled as a .jsonl file where each line has:
|
||||
|
||||
{
|
||||
"episode_id": "ep_001",
|
||||
"t": 37,
|
||||
"images": ["path/to/cam0_t.jpg", "path/to/cam1_t.jpg"],
|
||||
"skill_current": "pick up the KitKat",
|
||||
"skill_history": ["open fridge", "pick up lettuce", "place lettuce"],
|
||||
"task_description": "making a sandwich"
|
||||
}
|
||||
|
||||
📤 OUTPUT FORMAT (D_syn)
|
||||
|
||||
Each line of synthetically generated data should be:
|
||||
|
||||
{
|
||||
"episode_id": "ep_001",
|
||||
"t": 37,
|
||||
"images": ["path/to/cam0_t.jpg", "path/to/cam1_t.jpg"],
|
||||
"skill_current": "pick up the KitKat",
|
||||
"skill_history": [...],
|
||||
"user_prompt": "Can you grab me something sweet?",
|
||||
"robot_utterance": "Sure, I can pick up the KitKat.",
|
||||
"task_description": "making a sandwich"
|
||||
}
|
||||
|
||||
|
||||
Store as syn_annotations.jsonl for debugging.
|
||||
|
||||
🧠 pgen MODEL (Qwen) REQUIREMENTS
|
||||
|
||||
Use HuggingFace Transformers:
|
||||
|
||||
Qwen/Qwen2-VL-7B-Instruct (or any Qwen2-VL Vision-Language model available)
|
||||
|
||||
Use the image + text chat interface
|
||||
|
||||
Vision inputs should be loaded with PIL
|
||||
|
||||
Use a single forward pass that outputs BOTH ℓ_t and u_t in a structured JSON
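
A minimal sketch of that call pattern with the transformers Qwen2-VL API (illustrative only: generation settings are arbitrary and error handling for malformed JSON is omitted):

```python
import json
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

model_id = "Qwen/Qwen2-VL-7B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="cuda")
processor = AutoProcessor.from_pretrained(model_id)

images = [Image.open(p) for p in ["path/to/cam0_t.jpg"]]
prompt = "..."  # full pgen prompt built from task_description, skill_history, skill_current

messages = [{"role": "user",
             "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": prompt}]}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=images, return_tensors="pt").to(model.device)

out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
decoded = processor.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
result = json.loads(decoded)  # expects {"user_prompt": "...", "robot_utterance": "..."}
```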
|
||||
|
||||
📝 PROMPT FORMAT FOR pgen
|
||||
|
||||
Create a template like:
|
||||
|
||||
You are a robot-assistant dialogue generator for hierarchical robot policies.
|
||||
|
||||
You will receive:
|
||||
- A list of images showing the current robot scene.
|
||||
- The high-level task: {task_description}
|
||||
- Previous skill steps completed: {skill_history}
|
||||
- The next skill to be performed by the robot: {skill_current}
|
||||
|
||||
Generate two things in JSON:
|
||||
1. "user_prompt": a natural-sounding user request that logically leads to the robot performing the skill "{skill_current}" given the task and history.
|
||||
2. "robot_utterance": a natural robot reply acknowledging or clarifying the request.
|
||||
|
||||
The responses must be grounded in the visual scene, the task, and the skill history.
|
||||
|
||||
Respond ONLY in JSON:
|
||||
{
|
||||
"user_prompt": "...",
|
||||
"robot_utterance": "..."
|
||||
}
|
||||
|
||||
This response will have a corresponding task_index, and the task will be saved in tasks.parquet. You must update each dataset parquet (for example /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace/data/chunk-000/file-000.parquet) to include this new feature called task_index_high_level; consider updating the metadata in info.json as well.
|
||||
📌 LOGIC REQUIRED
|
||||
construct_prompt(sample)
|
||||
|
||||
Loads sample dict
|
||||
|
||||
Inserts:
|
||||
|
||||
task_description
|
||||
|
||||
skill_history
|
||||
|
||||
skill_current
|
||||
|
||||
Returns a full text prompt string
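
A sketch of how `construct_prompt()` could fill the template (abridged template here; the full PGEN_PROMPT_TEMPLATE is the prompt text above):

```python
PGEN_PROMPT_TEMPLATE = """You are a robot-assistant dialogue generator for hierarchical robot policies.
- The high-level task: {task_description}
- Previous skill steps completed: {skill_history}
- The next skill to be performed by the robot: {skill_current}
Respond ONLY in JSON with "user_prompt" and "robot_utterance"."""


def construct_prompt(sample: dict) -> str:
    # Fill the template with the fields carried by each Dlabeled sample.
    return PGEN_PROMPT_TEMPLATE.format(
        task_description=sample["task_description"],
        skill_history=", ".join(sample["skill_history"]) or "none",
        skill_current=sample["skill_current"],
    )
```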
|
||||
|
||||
call_qwen(images, prompt)
|
||||
|
||||
Loads images into Qwen-VL multimodal input format
|
||||
|
||||
Calls model.generate
|
||||
|
||||
Parses JSON output
|
||||
|
||||
annotate_sample(sample)
|
||||
|
||||
Builds prompt
|
||||
|
||||
Calls Qwen
|
||||
|
||||
Returns augmented sample with user_prompt + robot_utterance
|
||||
|
||||
🚀 CLI Usage
|
||||
|
||||
The script should run as:
|
||||
|
||||
python annotate_pgen.py \
|
||||
--output-dir PATH \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--repo-id lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
|
||||
--batch-size 1
|
||||
|
||||
|
||||
Include arguments via argparse.
|
||||
|
||||
🔧 OTHER REQUIREMENTS
|
||||
|
||||
Use tqdm for progress bars
|
||||
|
||||
Log errors gracefully and continue
|
||||
|
||||
Support GPU acceleration (device="cuda")
|
||||
|
||||
Cache model loading so it's not reloaded every call
|
||||
|
||||
Make the prompt deterministic but allow temperature parameter
|
||||
|
||||
Add a flag --num-image-views-per-sample
|
||||
|
||||
Add automatic JSON parsing with helpful error messages
|
||||
|
||||
🎯 FINAL DELIVERABLE

Cursor must now generate a full Python file named annotate_pgen.py implementing the above functionality end-to-end. It should be production-ready, runnable on real data, cleanly structured, and easy to modify.
from the paper:

Next, we use a large vision-language model (VLM) p_gen to produce synthetic user prompts and interjections ℓ_t and corresponding robot utterances u_t. Given D_labeled, we prompt p_gen with both the visual context I^1_t, ..., I^n_t and the skill label ℓ̂_t (e.g., "pick up the lettuce"). p_gen then imagines an appropriate interaction that might have led to ℓ̂_t in a real user interaction: it generates possible user prompts ℓ_t (e.g., "Can you add some lettuce for me?") along with the robot's verbal responses and clarifications u_t. We detail the [...]

A. Synthetic Data Generation

A.1. Scenario and Response Categorization

To ensure the quality and diversity of the synthetic data, we incorporate structured scenario classification and response categorization into the prompt design for p_gen, following (Stephan et al., 2024). Specifically, we classify interactions into different scenario types, such as negative task (where the user instructs the robot what not to do), situated correction (where the user adjusts an earlier command based on the evolving task state), and specific constraint (where the user specifies particular constraints, such as dietary preferences). In addition, we categorize the robot's responses into types such as simple confirmations, clarifications, and error handling. These classifications guide the generation process to ensure a broad range of user-robot interactions.

A.2. Prompt Construction for Contextual Grounding

In prompt P, we include a detailed description of the task (e.g., bussing a table, making a sandwich, grocery shopping) and instruct the model to ground responses in visual observations and prior context. A key advantage of leveraging large pretrained VLMs is their ability to incorporate world knowledge when generating interactions. For instance, the model can infer dietary constraints when generating prompts for sandwich-making, producing user commands such as "Can you make a sandwich for me? I'm lactose intolerant" and an appropriate robot response like "Sure, I won't put cheese on it." Similarly, it can reason over ambiguous or implicit requests, such as inferring that "I want something sweet" in a grocery shopping scenario should lead to suggestions like chocolate or candy.

To maintain consistency in multi-step tasks, we condition p_gen on prior skill labels within an episode ℓ̂_0, ..., ℓ̂_{t-1}, allowing it to generate coherent user commands that account for past actions. For instance, if the robot has already placed lettuce and tomato on a sandwich, the generated user prompt might request additional ingredients that logically follow. This ensures that the synthetic interactions reflect realistic task progression rather than isolated commands. As such, we leverage p_gen(ℓ_t, u_t | I^1_t, ..., I^n_t, ℓ̂_0, ..., ℓ̂_{t-1}, ℓ̂_t, P) to produce a richer, more diverse synthetic dataset D_syn that provides meaningful supervision for training our high-level policy.

While in this work we generate a separate D_syn and train a separate high-level policy for each task (e.g., sandwich making vs. table cleaning) for clarity and ease of benchmarking, the architecture is readily amenable to a unified multi-task formulation. In principle, the same hierarchical approach could be used to train a single high-level policy across a multitude of tasks, facilitating knowledge transfer.
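Reassembled from the extraction above, the conditioning used for generation reads:

```latex
p_{\mathrm{gen}}\left(\ell_t, u_t \mid I^1_t, \ldots, I^n_t,\ \hat{\ell}_0, \ldots, \hat{\ell}_{t-1},\ \hat{\ell}_t,\ P\right)
```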
The result should be a new LeRobotDataset with a new feature called task_index_high_level inside each dataset parquet; a rough sketch of the parquet and metadata update follows.
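A sketch of that update for a single data parquet plus info.json; the column name and paths follow the spec, while the per-frame index values and the exact feature-schema entry are assumptions:

```python
import json
from pathlib import Path

import pandas as pd


def add_high_level_task_index(parquet_path: Path, task_indices: list[int], info_path: Path) -> None:
    """Add a task_index_high_level column to one data parquet and register it in info.json."""
    df = pd.read_parquet(parquet_path)
    assert len(task_indices) == len(df), "expected one high-level task index per frame"
    df["task_index_high_level"] = task_indices
    df.to_parquet(parquet_path, engine="pyarrow", compression="snappy")

    # Assumed schema entry; match it to the other scalar features already listed in info.json.
    info = json.loads(info_path.read_text())
    info["features"]["task_index_high_level"] = {"dtype": "int64", "shape": [1], "names": None}
    info_path.write_text(json.dumps(info, indent=4))
```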
@@ -1,10 +0,0 @@
# python examples/dataset/annotate.py \
#     --repo-id lerobot/svla_so101_pickplace \
#     --video-key observation.images.side \
#     --model Qwen/Qwen3-VL-30B-A3B-Instruct \

python examples/dataset/annotate.py \
    --repo-id lerobot/svla_so101_pickplace \
    --video-key observation.images.side \
    --model Qwen/Qwen3-VL-30B-A3B-Instruct \
    --episodes 3 5 7 44
@@ -1,42 +0,0 @@
#!/bin/bash

# Example script to run synthetic data generation with Qwen VLM
# This generates user prompts and robot utterances for hierarchical policy training

# Configuration
REPO_ID="lerobot/svla_so101_pickplace"
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
# Alternative: MODEL="Qwen/Qwen2-VL-7B-Instruct"

OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations1"
BATCH_SIZE=32
TEMPERATURE=0.9
SAMPLE_INTERVAL=5.0  # Generate dialogue every 5 seconds (all episodes processed)

# Run synthetic data generation (processes ALL episodes)
python examples/dataset/annotate_pgen.py \
    --repo-id "$REPO_ID" \
    --model "$MODEL" \
    --output-dir "$OUTPUT_DIR" \
    --temperature "$TEMPERATURE" \
    --batch-size "$BATCH_SIZE" \
    --sample-interval "$SAMPLE_INTERVAL" \
    --num-image-views-per-sample 1

# For faster testing, increase the sample interval:
#   --sample-interval 5.0  # Samples every 5 seconds (much faster)

# To push to hub after generation, add the --push-to-hub flag.

# Efficient batch processing: 4 episodes at once
# python examples/dataset/annotate_pgen.py \
#     --repo-id "$REPO_ID" \
#     --model "$MODEL" \
#     --output-dir "$OUTPUT_DIR" \
#     --video-mode \
#     --video-key observation.images.up \
#     --video-batch-size "$BATCH_SIZE" \
#     --sample-interval 1.0
@@ -1,802 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM Subtask Annotation using local GPU (Qwen3-VL).
|
||||
|
||||
This script implements the annotation approach from the SARM paper using local GPU inference:
|
||||
"SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation"
|
||||
Paper: https://arxiv.org/pdf/2509.25358
|
||||
|
||||
What it does:
|
||||
1. Takes videos from a LeRobot dataset
|
||||
2. Uses Qwen3-VL running locally on GPU to identify when subtasks occur
|
||||
3. Saves subtask timestamps to the dataset metadata
|
||||
4. Optionally pushes the annotated dataset to HuggingFace Hub
|
||||
|
||||
SARM trains reward models that predict:
|
||||
- Stage: Which subtask is currently being executed (discrete classification)
|
||||
- Progress: How far along the subtask we are (continuous 0-1)
|
||||
|
||||
Supports three annotation modes:
|
||||
1. No annotations (no args): Auto-creates single sparse "task" stage covering full episode.
|
||||
Use with SARM config annotation_mode="single_stage" for simple tasks.
|
||||
|
||||
2. Dense-only (--dense-only --dense-subtasks): Dense annotations from VLM, auto-generated
|
||||
single sparse "task" stage. Use with annotation_mode="dense_only".
|
||||
|
||||
3. Dual mode (--sparse-subtasks + --dense-subtasks): Both sparse and dense annotations
|
||||
from VLM. Use with annotation_mode="dual".
|
||||
|
||||
Requirements:
- GPU with sufficient VRAM (16GB+ recommended for 30B model)
- `pip install transformers torch qwen-vl-utils`
|
||||
|
||||
Run with:
|
||||
```bash
|
||||
python examples/dataset_annotation/subtask_annotation.py \
|
||||
--repo-id your-username/your-dataset \
|
||||
--sparse-subtasks "Do ..." \
|
||||
--dense-subtasks "Do task 1, Do task 2, Do task 3" \
|
||||
--video-key observation.images.base \
|
||||
--push-to-hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import pandas as pd
|
||||
import torch
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from rich.console import Console
|
||||
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
Subtask,
|
||||
SubtaskAnnotation,
|
||||
Timestamp,
|
||||
compute_temporal_proportions,
|
||||
)
|
||||
|
||||
|
||||
def create_sarm_prompt(subtask_list: list[str]) -> str:
|
||||
subtask_str = "\n".join([f" - {name}" for name in subtask_list])
|
||||
|
||||
return textwrap.dedent(f"""\
|
||||
# Role
|
||||
You are a Robotics Vision System specializing in temporal action localization for robot manipulation. Your job is to segment a single demonstration video into distinct, non-overlapping atomic actions from a fixed subtask list.
|
||||
|
||||
# Subtask Label Set (Closed Vocabulary)
|
||||
You must strictly identify the video segments using ONLY the following labels. Do not create new labels or modify existing ones:
|
||||
|
||||
[
|
||||
{subtask_str}
|
||||
]
|
||||
|
||||
The video shows one successful execution of all subtasks in a logical order.
|
||||
|
||||
# Ground-Truth Semantics (Very Important)
|
||||
Use **visual state changes** to define when a subtask starts and ends. Do NOT assume equal durations for the subtasks.
|
||||
|
||||
- A subtask **starts** at the first frame where the robot's motion clearly initiates that subtask.
|
||||
- A subtask **ends** at the first frame where that specific action is visually completed and the manipulated object reaches a temporary, stable configuration.
|
||||
|
||||
If there are short pauses or micro-motions that don't clearly correspond to a new subtask, they belong to the **current** subtask.
|
||||
|
||||
# Hard Constraints & Logic
|
||||
1. **Continuous Coverage (No Gaps):**
|
||||
- The entire video duration from "00:00" to the final timestamp must be covered by subtasks.
|
||||
- There can be no gaps between subtasks.
|
||||
- If there is any idle or ambiguous time between clear actions, extend the *preceding* subtask to cover it.
|
||||
|
||||
2. **Boundary Consistency:**
|
||||
- The `"end"` timestamp of one subtask must be exactly equal to the `"start"` timestamp of the next subtask.
|
||||
- Boundaries must coincide with a real visual state transition, not just a convenient time split.
|
||||
|
||||
3. **Chronological Order, One Occurrence Each:**
|
||||
- This is a single successful demonstration.
|
||||
- Each subtask from the vocabulary appears **exactly once**, in the correct logical order.
|
||||
- **Durations may be very different** between subtasks. Never assume they are similar lengths. Base all boundaries only on the video.
|
||||
|
||||
4. **Reject Uniform Segmentation (Important):**
|
||||
- Do NOT simply divide the video into equal or nearly equal time chunks.
|
||||
- If your boundaries would result in subtasks with similar durations (e.g. all around 5 seconds), treat this as evidence that your segmentation is wrong and refine the boundaries.
|
||||
- Only use nearly equal durations if the video truly shows each subtask taking the same amount of time (this is very rare).
|
||||
|
||||
5. **Timestamps:**
|
||||
- Timestamps must be in `"MM:SS"` format.
|
||||
- The first subtask always starts at `"00:00"`.
|
||||
- The last subtask ends at the final visible frame of the video.
|
||||
|
||||
# Step 1 — Textual Timeline (must do this first)
|
||||
First, write an extensive and detailed textual timeline describing what happens in the video with approximate timestamps.
|
||||
For each subtask, include:
|
||||
- its name
|
||||
- an approximate start and end time,
|
||||
- a description of the visual event at the boundary (e.g. "shirt fully folded to the left", "robot rotates folded shirt 90 degrees").
|
||||
|
||||
Format this as a bullet list.
|
||||
|
||||
# Step 2 — JSON Output (final answer)
|
||||
After the textual timeline, output **only** valid JSON with this structure.
|
||||
The JSON **must** be consistent with the textual timeline above:
|
||||
|
||||
{{
|
||||
"subtasks": [
|
||||
{{
|
||||
"name": "EXACT_NAME_FROM_LIST",
|
||||
"timestamps": {{
|
||||
"start": "MM:SS",
|
||||
"end": "MM:SS"
|
||||
}}
|
||||
}},
|
||||
{{
|
||||
"name": "EXACT_NAME_FROM_LIST",
|
||||
"timestamps": {{
|
||||
"start": "MM:SS",
|
||||
"end": "MM:SS"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Do not add any extra keys to the JSON.
|
||||
""")
|
||||
|
||||
|
||||
class VideoAnnotator:
|
||||
"""Annotates robot manipulation videos using local Qwen3-VL model on GPU"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subtask_list: list[str],
|
||||
model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct",
|
||||
device: str = "cuda",
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
model: "Qwen3VLMoeForConditionalGeneration | None" = None,
|
||||
processor: "AutoProcessor | None" = None,
|
||||
):
|
||||
"""
|
||||
Initialize the video annotator with local model.
|
||||
|
||||
Args:
|
||||
subtask_list: List of allowed subtask names (for consistency)
|
||||
model_name: Hugging Face model name (default: Qwen/Qwen3-VL-30B-A3B-Instruct)
|
||||
device: Device to use (cuda, cpu)
|
||||
torch_dtype: Data type for model (bfloat16, float16, float32)
|
||||
model: Pre-loaded model instance (optional, to share between annotators)
|
||||
processor: Pre-loaded processor instance (optional, to share between annotators)
|
||||
"""
|
||||
self.subtask_list = subtask_list
|
||||
self.prompt = create_sarm_prompt(subtask_list)
|
||||
self.console = Console()
|
||||
self.device = device
|
||||
|
||||
# Use provided model/processor or load new ones
|
||||
if model is not None and processor is not None:
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.console.print(f"[green]✓ Using shared model on {device}[/green]")
|
||||
else:
|
||||
self.console.print(f"[cyan]Loading model: {model_name}...[/cyan]")
|
||||
|
||||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||||
|
||||
def extract_episode_segment(
|
||||
self, file_path: Path, start_timestamp: float, end_timestamp: float, target_fps: int = 1
|
||||
) -> Path:
|
||||
"""
|
||||
Extract a specific episode segment from concatenated video.
|
||||
Uses minimal compression to preserve quality for local inference.
|
||||
|
||||
Args:
|
||||
file_path: Path to the concatenated video file
|
||||
start_timestamp: Starting timestamp in seconds (within this video file)
|
||||
end_timestamp: Ending timestamp in seconds (within this video file)
|
||||
target_fps: Target FPS (default: 1 for faster processing)
|
||||
|
||||
Returns:
|
||||
Path to extracted video file
|
||||
"""
|
||||
# Create temporary file for extracted video
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
||||
tmp_path = Path(tmp_file.name)
|
||||
tmp_file.close()
|
||||
|
||||
try:
|
||||
# Check if ffmpeg is available
|
||||
subprocess.run(
|
||||
["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
|
||||
)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as e:
raise RuntimeError("ffmpeg not found, cannot extract episode segment") from e
|
||||
|
||||
try:
|
||||
# Calculate duration
|
||||
duration = end_timestamp - start_timestamp
|
||||
|
||||
self.console.print(
|
||||
f"[cyan]Extracting episode: {start_timestamp:.1f}s-{end_timestamp:.1f}s ({duration:.1f}s)[/cyan]"
|
||||
)
|
||||
|
||||
# Use ffmpeg to extract segment with minimal quality loss
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
str(file_path),
|
||||
"-ss",
|
||||
str(start_timestamp),
|
||||
"-t",
|
||||
str(duration),
|
||||
"-r",
|
||||
str(target_fps),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-an",
|
||||
"-y",
|
||||
str(tmp_path),
|
||||
]
|
||||
|
||||
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
||||
|
||||
# Verify the output file was created and is not empty
|
||||
if not tmp_path.exists() or tmp_path.stat().st_size == 0:
|
||||
self.console.print("[red]✗ Video extraction failed (0 bytes) - skipping episode[/red]")
|
||||
if tmp_path.exists():
|
||||
tmp_path.unlink()
|
||||
raise RuntimeError("FFmpeg produced empty video file")
|
||||
|
||||
# Show extraction results
|
||||
file_size_mb = tmp_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
# Fail if file is too small (< 100KB likely means extraction failed)
|
||||
if file_size_mb < 0.1:
|
||||
self.console.print(
|
||||
f"[red]✗ Extracted video too small ({file_size_mb:.2f}MB) - skipping episode[/red]"
|
||||
)
|
||||
tmp_path.unlink()
|
||||
raise RuntimeError(f"Video extraction produced invalid file ({file_size_mb:.2f}MB)")
|
||||
|
||||
self.console.print(f"[green]✓ Extracted: {file_size_mb:.1f}MB ({target_fps} FPS)[/green]")
|
||||
|
||||
return tmp_path
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"ffmpeg failed ({e})") from e
|
||||
|
||||
def annotate(
|
||||
self,
|
||||
file_path: str | Path,
|
||||
fps: int,
|
||||
start_timestamp: float = 0.0,
|
||||
end_timestamp: float | None = None,
|
||||
max_retries: int = 3,
|
||||
) -> SubtaskAnnotation:
|
||||
"""Annotate a video segment using local GPU."""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if end_timestamp is None:
|
||||
cap = cv2.VideoCapture(str(file_path))
|
||||
end_timestamp = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) / (cap.get(cv2.CAP_PROP_FPS) or 1)
|
||||
cap.release()
|
||||
|
||||
duration = end_timestamp - start_timestamp
|
||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||
|
||||
extracted_path = self.extract_episode_segment(file_path, start_timestamp, end_timestamp, 1)
|
||||
is_extracted = extracted_path != file_path
|
||||
|
||||
try:
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": self.prompt}]},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video", "video": str(extracted_path), "fps": 1.0},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Video is {duration_str} (~{duration:.1f}s). Follow instructions.",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
skip_special_tokens=True,
|
||||
)[0].strip()
|
||||
|
||||
# Extract JSON
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
try:
|
||||
return SubtaskAnnotation.model_validate(json.loads(response))
|
||||
except json.JSONDecodeError:
|
||||
match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if match:
|
||||
return SubtaskAnnotation.model_validate(json.loads(match.group()))
|
||||
raise ValueError("No JSON found")
|
||||
except Exception as e:
|
||||
if attempt == max_retries - 1:
|
||||
raise RuntimeError(f"Failed after {max_retries} attempts") from e
|
||||
time.sleep(1)
|
||||
finally:
|
||||
if is_extracted and extracted_path.exists():
|
||||
extracted_path.unlink()
|
||||
|
||||
|
||||
def display_annotation(
|
||||
annotation: SubtaskAnnotation, console: Console, episode_idx: int, fps: int, prefix: str = ""
|
||||
):
|
||||
"""Display annotation summary."""
|
||||
subtask_summary = ", ".join(
|
||||
f"{s.name}({s.timestamps.start}-{s.timestamps.end})" for s in annotation.subtasks
|
||||
)
|
||||
console.print(
|
||||
f"[green]Episode {episode_idx} {prefix}: {len(annotation.subtasks)} subtasks - {subtask_summary}[/green]"
|
||||
)
|
||||
|
||||
|
||||
def timestamp_to_seconds(timestamp: str) -> float:
|
||||
"""Convert MM:SS or SS timestamp to seconds"""
|
||||
parts = timestamp.split(":")
|
||||
if len(parts) == 2:
|
||||
return int(parts[0]) * 60 + int(parts[1])
|
||||
else:
|
||||
return int(parts[0])
|
||||
|
||||
|
||||
def save_annotations_to_dataset(
|
||||
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
|
||||
):
|
||||
"""Save annotations to LeRobot dataset parquet format."""
|
||||
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH, load_episodes
|
||||
|
||||
episodes_dataset = load_episodes(dataset_path)
|
||||
if not episodes_dataset or len(episodes_dataset) == 0:
|
||||
return
|
||||
|
||||
episodes_df = episodes_dataset.to_pandas()
|
||||
cols = [
|
||||
f"{prefix}_{c}"
|
||||
for c in [
|
||||
"subtask_names",
|
||||
"subtask_start_times",
|
||||
"subtask_end_times",
|
||||
"subtask_start_frames",
|
||||
"subtask_end_frames",
|
||||
]
|
||||
]
|
||||
for col in cols:
|
||||
episodes_df[col] = None
|
||||
|
||||
for ep_idx, ann in annotations.items():
|
||||
if ep_idx >= len(episodes_df):
|
||||
continue
|
||||
names, starts, ends, start_frames, end_frames = [], [], [], [], []
|
||||
for s in ann.subtasks:
|
||||
names.append(s.name)
|
||||
st, et = timestamp_to_seconds(s.timestamps.start), timestamp_to_seconds(s.timestamps.end)
|
||||
starts.append(st)
|
||||
ends.append(et)
|
||||
start_frames.append(int(st * fps))
|
||||
end_frames.append(int(et * fps))
|
||||
episodes_df.at[ep_idx, cols[0]] = names
|
||||
episodes_df.at[ep_idx, cols[1]] = starts
|
||||
episodes_df.at[ep_idx, cols[2]] = ends
|
||||
episodes_df.at[ep_idx, cols[3]] = start_frames
|
||||
episodes_df.at[ep_idx, cols[4]] = end_frames
|
||||
|
||||
# Group by file and write
|
||||
for ep_idx in episodes_df.index:
|
||||
key = (
|
||||
episodes_df.loc[ep_idx, "meta/episodes/chunk_index"],
|
||||
episodes_df.loc[ep_idx, "meta/episodes/file_index"],
|
||||
)
|
||||
path = dataset_path / DEFAULT_EPISODES_PATH.format(chunk_index=key[0], file_index=key[1])
|
||||
if path.exists():
|
||||
file_df = pd.read_parquet(path)
|
||||
for col in cols + (
|
||||
[
|
||||
"subtask_names",
|
||||
"subtask_start_times",
|
||||
"subtask_end_times",
|
||||
"subtask_start_frames",
|
||||
"subtask_end_frames",
|
||||
]
|
||||
if prefix == "sparse"
|
||||
else []
|
||||
):
|
||||
if col not in file_df.columns:
|
||||
file_df[col] = None
|
||||
if ep_idx in annotations:
|
||||
for col in cols:
|
||||
file_df.at[ep_idx, col] = episodes_df.loc[ep_idx, col]
|
||||
if prefix == "sparse": # Legacy columns
|
||||
for i, legacy in enumerate(
|
||||
[
|
||||
"subtask_names",
|
||||
"subtask_start_times",
|
||||
"subtask_end_times",
|
||||
"subtask_start_frames",
|
||||
"subtask_end_frames",
|
||||
]
|
||||
):
|
||||
file_df.at[ep_idx, legacy] = episodes_df.loc[ep_idx, cols[i]]
|
||||
file_df.to_parquet(path, engine="pyarrow", compression="snappy")
|
||||
|
||||
|
||||
def generate_auto_sparse_annotations(
|
||||
dataset: LeRobotDataset, episode_indices: list[int], video_key: str
|
||||
) -> dict[int, SubtaskAnnotation]:
|
||||
"""Auto-generate single 'task' stage annotations for all episodes."""
|
||||
annotations = {}
|
||||
for ep_idx in episode_indices:
|
||||
start = float(dataset.meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
|
||||
end = float(dataset.meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
|
||||
duration = end - start
|
||||
end_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||
annotations[ep_idx] = SubtaskAnnotation(
|
||||
subtasks=[Subtask(name="task", timestamps=Timestamp(start="00:00", end=end_str))]
|
||||
)
|
||||
return annotations
|
||||
|
||||
|
||||
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
|
||||
"""Load annotations from LeRobot dataset parquet files."""
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
|
||||
episodes_dataset = load_episodes(dataset_path)
|
||||
if not episodes_dataset or len(episodes_dataset) == 0:
|
||||
return {}
|
||||
|
||||
col_names = f"{prefix}_subtask_names"
|
||||
col_start = f"{prefix}_subtask_start_times"
|
||||
col_end = f"{prefix}_subtask_end_times"
|
||||
|
||||
# Fall back to legacy columns for sparse
|
||||
if col_names not in episodes_dataset.column_names:
|
||||
if prefix == "sparse" and "subtask_names" in episodes_dataset.column_names:
|
||||
col_names, col_start, col_end = "subtask_names", "subtask_start_times", "subtask_end_times"
|
||||
else:
|
||||
return {}
|
||||
|
||||
df = episodes_dataset.to_pandas()
|
||||
annotations = {}
|
||||
for ep_idx in df.index:
|
||||
names = df.loc[ep_idx, col_names]
|
||||
if names is None or (isinstance(names, float) and pd.isna(names)):
|
||||
continue
|
||||
starts, ends = df.loc[ep_idx, col_start], df.loc[ep_idx, col_end]
|
||||
annotations[int(ep_idx)] = SubtaskAnnotation(
|
||||
subtasks=[
|
||||
Subtask(
|
||||
name=n,
|
||||
timestamps=Timestamp(
|
||||
start=f"{int(s) // 60:02d}:{int(s) % 60:02d}",
|
||||
end=f"{int(e) // 60:02d}:{int(e) % 60:02d}",
|
||||
),
|
||||
)
|
||||
for n, s, e in zip(names, starts, ends)
|
||||
]
|
||||
)
|
||||
return annotations
|
||||
|
||||
|
||||
def process_single_episode(
|
||||
ep_idx: int,
|
||||
dataset_root: Path,
|
||||
dataset_meta,
|
||||
video_key: str,
|
||||
fps: int,
|
||||
annotator: VideoAnnotator,
|
||||
console: Console,
|
||||
) -> tuple[int, SubtaskAnnotation | None, str | None]:
|
||||
"""Process a single episode annotation."""
|
||||
try:
|
||||
video_path = dataset_root / dataset_meta.get_video_file_path(ep_idx, video_key)
|
||||
if not video_path.exists():
|
||||
return ep_idx, None, f"Video not found: {video_path}"
|
||||
|
||||
start = float(dataset_meta.episodes[f"videos/{video_key}/from_timestamp"][ep_idx])
|
||||
end = float(dataset_meta.episodes[f"videos/{video_key}/to_timestamp"][ep_idx])
|
||||
return ep_idx, annotator.annotate(video_path, fps, start, end), None
|
||||
except Exception as e:
|
||||
return ep_idx, None, str(e)
|
||||
|
||||
|
||||
def worker_process_episodes(
|
||||
worker_id: int,
|
||||
gpu_id: int,
|
||||
episode_indices: list[int],
|
||||
repo_id: str,
|
||||
video_key: str,
|
||||
sparse_subtask_list: list[str],
|
||||
dense_subtask_list: list[str] | None,
|
||||
model_name: str,
|
||||
torch_dtype: torch.dtype,
|
||||
) -> tuple[dict, dict | None]:
|
||||
"""Worker for parallel processing across GPUs."""
|
||||
device = f"cuda:{gpu_id}"
|
||||
console = Console()
|
||||
dataset = LeRobotDataset(repo_id, download_videos=False)
|
||||
|
||||
sparse_annotator = VideoAnnotator(sparse_subtask_list, model_name, device, torch_dtype)
|
||||
dense_annotator = (
|
||||
VideoAnnotator(
|
||||
dense_subtask_list,
|
||||
model_name,
|
||||
device,
|
||||
torch_dtype,
|
||||
sparse_annotator.model,
|
||||
sparse_annotator.processor,
|
||||
)
|
||||
if dense_subtask_list
|
||||
else None
|
||||
)
|
||||
|
||||
sparse_annotations, dense_annotations = {}, {} if dense_subtask_list else None
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
_, sparse_ann, err = process_single_episode(
|
||||
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, sparse_annotator, console
|
||||
)
|
||||
if sparse_ann:
|
||||
sparse_annotations[ep_idx] = sparse_ann
|
||||
|
||||
if dense_annotator:
|
||||
_, dense_ann, _ = process_single_episode(
|
||||
ep_idx, dataset.root, dataset.meta, video_key, dataset.fps, dense_annotator, console
|
||||
)
|
||||
if dense_ann:
|
||||
dense_annotations[ep_idx] = dense_ann
|
||||
|
||||
return sparse_annotations, dense_annotations
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="SARM-style subtask annotation using local GPU (Qwen3-VL)")
|
||||
parser.add_argument("--repo-id", type=str, required=True, help="HuggingFace dataset repository ID")
|
||||
parser.add_argument(
|
||||
"--sparse-subtasks", type=str, default=None, help="Comma-separated sparse subtask names"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dense-subtasks", type=str, default=None, help="Comma-separated dense subtask names"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dense-only", action="store_true", help="Dense-only mode with auto-generated sparse 'task' stage"
|
||||
)
|
||||
parser.add_argument("--episodes", type=int, nargs="+", default=None, help="Episode indices to annotate")
|
||||
parser.add_argument("--model", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct", help="VLM model")
|
||||
parser.add_argument("--skip-existing", action="store_true", help="Skip already annotated episodes")
|
||||
parser.add_argument("--video-key", type=str, default=None, help="Video key (default: first available)")
|
||||
parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
|
||||
parser.add_argument("--output-repo-id", type=str, default=None, help="Output repo ID for push")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
|
||||
parser.add_argument("--num-workers", type=int, default=1, help="Parallel workers for multi-GPU")
|
||||
parser.add_argument("--gpu-ids", type=int, nargs="+", default=None, help="GPU IDs to use")
|
||||
|
||||
args = parser.parse_args()
|
||||
console = Console()
|
||||
|
||||
# Validate arguments
|
||||
if args.dense_only and not args.dense_subtasks:
|
||||
return console.print("[red]Error: --dense-only requires --dense-subtasks[/red]")
|
||||
if args.dense_subtasks and not args.sparse_subtasks and not args.dense_only:
|
||||
return console.print("[red]Error: --dense-subtasks requires --sparse-subtasks or --dense-only[/red]")
|
||||
|
||||
sparse_subtask_list = (
|
||||
[s.strip() for s in args.sparse_subtasks.split(",")] if args.sparse_subtasks else None
|
||||
)
|
||||
dense_subtask_list = [s.strip() for s in args.dense_subtasks.split(",")] if args.dense_subtasks else None
|
||||
auto_sparse = sparse_subtask_list is None
|
||||
dense_mode = dense_subtask_list is not None
|
||||
torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
|
||||
|
||||
console.print(f"[cyan]Loading dataset: {args.repo_id}[/cyan]")
|
||||
dataset = LeRobotDataset(args.repo_id, download_videos=True)
|
||||
fps = dataset.fps
|
||||
|
||||
if not dataset.meta.video_keys:
|
||||
raise ValueError("No video keys found")
|
||||
|
||||
video_key = (
|
||||
args.video_key if args.video_key in (dataset.meta.video_keys or []) else dataset.meta.video_keys[0]
|
||||
)
|
||||
console.print(f"[cyan]Using camera: {video_key}, FPS: {fps}[/cyan]")
|
||||
|
||||
# Determine episodes
|
||||
episode_indices = args.episodes or list(range(dataset.meta.total_episodes))
|
||||
|
||||
existing_annotations = load_annotations_from_dataset(dataset.root, prefix="sparse")
|
||||
if args.skip_existing:
|
||||
episode_indices = [ep for ep in episode_indices if ep not in existing_annotations]
|
||||
|
||||
if not episode_indices:
|
||||
return console.print("[green]All episodes already annotated![/green]")
|
||||
console.print(f"[cyan]Annotating {len(episode_indices)} episodes[/cyan]")
|
||||
|
||||
# GPU setup
|
||||
gpu_ids = args.gpu_ids or list(
|
||||
range(min(args.num_workers, torch.cuda.device_count() if torch.cuda.is_available() else 1))
|
||||
)
|
||||
args.num_workers = len(gpu_ids)
|
||||
|
||||
sparse_annotations = existing_annotations.copy()
|
||||
dense_annotations = {} if dense_mode else None
|
||||
|
||||
# Auto-sparse mode
|
||||
if auto_sparse:
|
||||
sparse_annotations.update(generate_auto_sparse_annotations(dataset, episode_indices, video_key))
|
||||
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
|
||||
console.print(f"[green]Auto-generated {len(episode_indices)} sparse 'task' annotations[/green]")
|
||||
|
||||
# VLM annotation (for sparse if not auto, and for dense)
|
||||
need_vlm = (not auto_sparse) or dense_mode
|
||||
|
||||
if need_vlm:
|
||||
if args.num_workers > 1 and not auto_sparse:
|
||||
# Parallel processing
|
||||
console.print(f"[cyan]Parallel processing with {args.num_workers} workers[/cyan]")
|
||||
episodes_per_worker = [[] for _ in range(args.num_workers)]
|
||||
for i, ep_idx in enumerate(episode_indices):
|
||||
episodes_per_worker[i % args.num_workers].append(ep_idx)
|
||||
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=args.num_workers, mp_context=mp.get_context("spawn")
|
||||
) as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
worker_process_episodes,
|
||||
w,
|
||||
gpu_ids[w],
|
||||
episodes_per_worker[w],
|
||||
args.repo_id,
|
||||
video_key,
|
||||
sparse_subtask_list,
|
||||
dense_subtask_list,
|
||||
args.model,
|
||||
torch_dtype,
|
||||
)
|
||||
for w in range(args.num_workers)
|
||||
if episodes_per_worker[w]
|
||||
]
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
worker_sparse, worker_dense = future.result()
|
||||
sparse_annotations.update(worker_sparse)
|
||||
if dense_mode and worker_dense:
|
||||
dense_annotations.update(worker_dense)
|
||||
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
|
||||
if dense_mode:
|
||||
save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Worker failed: {e}") from e
|
||||
else:
|
||||
# Sequential processing
|
||||
sparse_annotator = (
|
||||
VideoAnnotator(sparse_subtask_list, args.model, args.device, torch_dtype)
|
||||
if not auto_sparse and sparse_subtask_list
|
||||
else None
|
||||
)
|
||||
dense_annotator = (
|
||||
VideoAnnotator(
|
||||
dense_subtask_list,
|
||||
args.model,
|
||||
args.device,
|
||||
torch_dtype,
|
||||
sparse_annotator.model if sparse_annotator else None,
|
||||
sparse_annotator.processor if sparse_annotator else None,
|
||||
)
|
||||
if dense_mode
|
||||
else None
|
||||
)
|
||||
|
||||
for i, ep_idx in enumerate(episode_indices):
|
||||
console.print(f"[cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/cyan]")
|
||||
|
||||
if sparse_annotator:
|
||||
_, sparse_ann, err = process_single_episode(
|
||||
ep_idx, dataset.root, dataset.meta, video_key, fps, sparse_annotator, console
|
||||
)
|
||||
if sparse_ann:
|
||||
sparse_annotations[ep_idx] = sparse_ann
|
||||
save_annotations_to_dataset(dataset.root, sparse_annotations, fps, prefix="sparse")
|
||||
elif err:
|
||||
console.print(f"[red]Sparse failed: {err}[/red]")
|
||||
|
||||
if dense_annotator:
|
||||
_, dense_ann, err = process_single_episode(
|
||||
ep_idx, dataset.root, dataset.meta, video_key, fps, dense_annotator, console
|
||||
)
|
||||
if dense_ann:
|
||||
dense_annotations[ep_idx] = dense_ann
|
||||
save_annotations_to_dataset(dataset.root, dense_annotations, fps, prefix="dense")
|
||||
elif err:
|
||||
console.print(f"[red]Dense failed: {err}[/red]")
|
||||
|
||||
# Save temporal proportions
|
||||
def save_proportions(annotations, prefix, is_auto=False):
|
||||
props: dict[str, float] = {"task": 1.0} if is_auto else compute_temporal_proportions(annotations, fps)
|
||||
path = dataset.root / "meta" / f"temporal_proportions_{prefix}.json"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
json.dump(props, f, indent=2)
|
||||
console.print(f"[green]Saved {prefix} temporal proportions[/green]")
|
||||
|
||||
save_proportions(sparse_annotations, "sparse", auto_sparse)
|
||||
if dense_mode and dense_annotations:
|
||||
save_proportions(dense_annotations, "dense")
|
||||
|
||||
console.print(
|
||||
f"\n[bold green]Complete! {len(sparse_annotations)} sparse, {len(dense_annotations or {})} dense annotations[/bold green]"
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
try:
|
||||
dataset.push_to_hub(push_videos=True)
|
||||
console.print(f"[green]Pushed to {args.output_repo_id or args.repo_id}[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Push failed: {e}[/red]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,44 +0,0 @@
#!/bin/bash

# Quick test to verify the fix for the task_indices length mismatch
# This should now work correctly even with --num-samples < full dataset length

echo "Testing annotate_pgen.py with --num-samples=100 on full dataset..."

python examples/dataset/annotate_pgen.py \
    --data-dir /fsx/jade_choghari/.cache/huggingface/lerobot/lerobot/svla_so101_pickplace \
    --model Qwen/Qwen3-VL-30B-A3B-Instruct \
    --num-samples 100 \
    --sample-interval 1.0 \
    --output-dir /fsx/jade_choghari/outputs/pgen_test_fixed
status=$?

if [ $status -eq 0 ]; then
    echo "✓ SUCCESS: Script completed without errors!"
    echo ""
    echo "Verifying output..."

    # Check that all frames have task_index_high_level
    python -c "
from lerobot.datasets.lerobot_dataset import LeRobotDataset

ds = LeRobotDataset(repo_id='local_test', root='/fsx/jade_choghari/outputs/pgen_test_fixed')
print(f'Dataset has {len(ds)} frames')
print(f'Features: {list(ds.features.keys())}')

# Check that task_index_high_level exists
assert 'task_index_high_level' in ds.features, 'task_index_high_level not in features!'

# Sample some frames
for idx in [0, 50, 99, 100, 500, 1000, 11938]:
    if idx < len(ds):
        frame = ds[idx]
        task_idx = frame['task_index_high_level'].item()
        print(f'Frame {idx}: task_index_high_level = {task_idx}')

print('✓ All checks passed!')
"
else
    echo "✗ FAILED: Script exited with error code $status"
fi
@@ -58,7 +58,6 @@ from lerobot.datasets.utils import (
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
load_tasks_high_level,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
@@ -162,7 +161,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.tasks_high_level = load_tasks_high_level(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -1052,12 +1050,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
# Optionally add high level task index
|
||||
if "task_index_high_level" in self.features:
|
||||
high_level_task_idx = item["task_index_high_level"].item()
|
||||
item["robot_utterance"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["robot_utterance"]
|
||||
item["user_prompt"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["user_prompt"]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -60,7 +60,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_TASKS_HIGH_LEVEL_PATH = "meta/tasks_high_level.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -353,9 +352,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
return tasks
|
||||
|
||||
def load_tasks_high_level(local_dir: Path) -> pandas.DataFrame:
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_HIGH_LEVEL_PATH)
|
||||
return tasks
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
|
||||
@@ -81,10 +81,14 @@ class AdamWConfig(OptimizerConfig):
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
fused: bool = False
|
||||
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
# Fused optimizer only works on CUDA
|
||||
if kwargs.get("fused") and not torch.cuda.is_available():
|
||||
kwargs["fused"] = False
|
||||
return torch.optim.AdamW(params, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -136,6 +136,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
optimizer_lr_backbone: float = 1e-5
|
||||
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
@@ -164,6 +165,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
|
||||
@@ -94,6 +94,7 @@ class GrootConfig(PreTrainedConfig):
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
@@ -174,6 +175,7 @@ class GrootConfig(PreTrainedConfig):
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
|
||||
@@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi0")
|
||||
@dataclass
|
||||
@@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
@@ -69,6 +74,7 @@ class PI0Config(PreTrainedConfig):
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
@@ -136,6 +142,7 @@ class PI0Config(PreTrainedConfig):
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
||||
@@ -41,7 +41,7 @@ else:
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
@@ -337,6 +337,7 @@ class PaliGemmaWithExpertModel(
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
@@ -356,6 +357,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.image_size = image_size
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
@@ -519,11 +521,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
|
||||
if config.image_resolution[0] != config.image_resolution[1]:
|
||||
raise ValueError(
|
||||
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||
)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, False],
|
||||
precision=config.dtype,
|
||||
image_size=config.image_resolution[0],
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||
@@ -812,16 +820,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
@@ -846,15 +851,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
|
||||
@@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@dataclass
|
||||
@@ -50,7 +52,10 @@ class PI05Config(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
empty_cameras: int = 0
|
||||
@@ -60,8 +65,8 @@ class PI05Config(PreTrainedConfig):
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.MEAN_STD, # Pi0.5 uses quantiles for action
|
||||
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
|
||||
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
|
||||
}
|
||||
)
|
||||
|
||||
@@ -69,6 +74,7 @@ class PI05Config(PreTrainedConfig):
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
|
||||
device: str | None = None # Device to use for the model (None = auto-detect)
|
||||
|
||||
# Optimizer settings: see openpi `AdamW`
|
||||
@@ -136,6 +142,7 @@ class PI05Config(PreTrainedConfig):
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
||||
@@ -41,17 +41,13 @@ else:
|
||||
PaliGemmaForConditionalGeneration = None
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_PROMPT_TOKENS,
|
||||
OBS_LANGUAGE_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TARGET_TOKENS,
|
||||
OBS_LANGUAGE_TARGET_ATTENTION_MASK,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -340,6 +336,7 @@ class PaliGemmaWithExpertModel(
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
@@ -359,6 +356,7 @@ class PaliGemmaWithExpertModel(
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.image_size = image_size
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
@@ -433,8 +431,6 @@ class PaliGemmaWithExpertModel(
|
||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||
)
|
||||
prefix_past_key_values = prefix_output.past_key_values
|
||||
# prefix_output to be used for the language head
|
||||
# shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048
|
||||
prefix_output = prefix_output.last_hidden_state
|
||||
suffix_output = None
|
||||
elif inputs_embeds[0] is None:
|
||||
@@ -524,11 +520,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
|
||||
if config.image_resolution[0] != config.image_resolution[1]:
|
||||
raise ValueError(
|
||||
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||
)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, True],
|
||||
precision=config.dtype,
|
||||
image_size=config.image_resolution[0],
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||
@@ -584,13 +586,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
if dtype is not None:
|
||||
result = result.to(dtype=dtype)
|
||||
return result
|
||||
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
@@ -609,29 +608,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
||||
"""Embed images with SigLIP, prompt tokens, and optionally target tokens with embedding layer.
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
prompt_tokens: Prompt tokens (input for generation)
|
||||
target_tokens: Target tokens to predict (can be None for inference)
|
||||
prompt_masks: Attention masks for prompt tokens
|
||||
target_masks: Attention masks for target tokens
|
||||
|
||||
Returns:
|
||||
embs: Concatenated embeddings [images, prompt_tokens, (target_tokens if provided)]
|
||||
pad_masks: Padding masks
|
||||
att_masks: Attention masks (with causal masking for target prediction if target_tokens provided)
|
||||
total_T_images: Total number of image tokens
|
||||
"""
|
||||
self, images, img_masks, tokens, masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
total_T_images = 0
|
||||
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
|
||||
@@ -643,48 +626,29 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
|
||||
total_T_images += num_img_embs
|
||||
|
||||
# Process prompt tokens
|
||||
def prompt_embed_func(prompt_tokens):
|
||||
prompt_emb = self.paligemma_with_expert.embed_language_tokens(prompt_tokens)
|
||||
prompt_emb_dim = prompt_emb.shape[-1]
|
||||
return prompt_emb * math.sqrt(prompt_emb_dim)
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
prompt_emb = self._apply_checkpoint(prompt_embed_func, prompt_tokens)
|
||||
embs.append(prompt_emb)
|
||||
pad_masks.append(prompt_masks)
|
||||
# Process language tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
num_prompt_embs = prompt_emb.shape[1]
|
||||
att_masks += [0] * num_prompt_embs # Prompt tokens can attend to all previous tokens (images + prompt)
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(masks)
|
||||
|
||||
# Process target tokens if provided (these are predicted, so use causal masking)
|
||||
if target_tokens is not None:
|
||||
def target_embed_func(target_tokens):
|
||||
target_emb = self.paligemma_with_expert.embed_language_tokens(target_tokens)
|
||||
target_emb_dim = target_emb.shape[-1]
|
||||
return target_emb * math.sqrt(target_emb_dim)
|
||||
|
||||
target_emb = self._apply_checkpoint(target_embed_func, target_tokens)
|
||||
embs.append(target_emb)
|
||||
|
||||
# Create target pad masks (non-zero tokens are valid)
|
||||
pad_masks.append(target_masks)
|
||||
|
||||
num_target_embs = target_emb.shape[1]
|
||||
# Causal masking for target tokens: each target token can attend to images, all prompt tokens,
|
||||
# and previous target tokens
|
||||
att_masks += [1] * num_target_embs # Use 1 for causal attention on target tokens
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
    return embs, pad_masks, att_masks, total_T_images
    return embs, pad_masks, att_masks

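`embed_prefix` returns `att_masks` as a per-token list of 0s and 1s that is later expanded into a 2D mask: 0 keeps a token inside the current bidirectional (prefix) block, 1 opens a new block that only attends to earlier blocks. A rough sketch of that expansion, under the assumption that the helper follows the usual pi0 cumulative-sum construction (the repo's `make_att_2d_masks` lives elsewhere and may differ in detail):

```python
import torch

def make_att_2d_masks_sketch(pad_masks: torch.Tensor, att_masks: torch.Tensor) -> torch.Tensor:
    """Prefix-LM convention: 0 = stay in the current bidirectional block,
    1 = start a new block that only sees earlier blocks; padding is masked out."""
    cumsum = torch.cumsum(att_masks.int(), dim=1)
    att_2d = cumsum[:, None, :] <= cumsum[:, :, None]   # key block index <= query block index
    pad_2d = pad_masks[:, None, :] & pad_masks[:, :, None]
    return att_2d & pad_2d

pad = torch.ones(1, 5, dtype=torch.bool)
att = torch.tensor([[0, 0, 0, 1, 1]])  # 3 prefix tokens, then 2 causally masked tokens
print(make_att_2d_masks_sketch(pad, att).int()[0])
```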
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
@@ -733,20 +697,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss.
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
prompt_tokens: Prompt tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
|
||||
prompt_masks: Attention masks for prompt_tokens
|
||||
target_tokens: Target tokens to predict (e.g., tokens for "pick up the cup")
|
||||
target_masks: Attention masks for target_tokens
|
||||
actions: Ground truth actions
|
||||
noise: Optional noise for flow matching
|
||||
time: Optional time for flow matching
|
||||
"""
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
@@ -756,57 +708,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# Embed prefix (images + prompt_tokens + target_tokens)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||
images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks
|
||||
)
|
||||
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
# Prepare attention masks for prefix-only pass (for target token prediction)
|
||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||
|
||||
# prefix-only transformer run for target token prediction
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None], # SUFFIX = None
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# LM HEAD → TARGET LOGITS
|
||||
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_prompt + T_target
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||
|
||||
# Extract logits for target token prediction (shifted by 1 for autoregressive training)
|
||||
# Position i predicts token i+1, so we take logits from positions before target tokens:
|
||||
# - Position (start_index-1) (last prompt token) predicts target_tokens[0]
|
||||
# - Position (start_index) (first target token) predicts target_tokens[1], etc.
|
||||
T_prompt = prompt_tokens.size(1)
|
||||
T_target = target_tokens.size(1)
|
||||
start_index = total_T_images + T_prompt
|
||||
end_index = start_index + T_target
|
||||
logits_target = logits[:, start_index-1:end_index-1, :] # (B, T_target, vocab)
|
||||
|
||||
# Compute cross-entropy loss
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
# Reshape for loss computation
|
||||
logits_flat = logits_target.reshape(-1, logits_target.size(-1)) # (B*T_target, vocab)
|
||||
targets_flat = target_tokens.reshape(-1) # (B*T_target)
|
||||
|
||||
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_target)
|
||||
loss_per_token = loss_per_token.reshape(target_tokens.shape) # (B, T_target)
|
||||
|
||||
# Apply mask and compute mean loss over valid tokens
|
||||
masked_loss = loss_per_token * target_masks.float()
|
||||
target_loss = masked_loss.sum() / target_masks.sum().clamp(min=1)
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
@@ -814,14 +719,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# Concatenate prefix (images + prompt_tokens + target_tokens) and suffix (actions) masks
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
# Prepare attention masks for full forward pass (prefix + suffix)
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
@@ -832,7 +736,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
# prefix_out to be used for the language head
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
@@ -847,104 +750,25 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||
|
||||
fm_loss = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
return {
|
||||
"flow_loss": fm_loss,
|
||||
"target_loss": target_loss,
|
||||
"loss": 10 * fm_loss.mean() + target_loss,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_target_tokens(
|
||||
self, images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_length, device
|
||||
):
|
||||
"""Generate target tokens autoregressively using next token prediction."""
|
||||
bsize = prompt_tokens.shape[0]
|
||||
|
||||
# Get lm_head for token generation
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# Embed prefix without target tokens first
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||
images, img_masks, prompt_tokens, target_tokens=None, prompt_masks=prompt_masks, target_masks=None
|
||||
)
|
||||
|
||||
# Initialize generated tokens list
|
||||
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
|
||||
|
||||
for t in range(max_length):
|
||||
# Prepare attention masks for current prefix
|
||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||
|
||||
# Forward pass through model to get logits
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Get logits from the last position
|
||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||
next_token_logits = logits[:, -1, :] # (B, vocab)
|
||||
|
||||
# Greedy decoding - take the most likely token
|
||||
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
|
||||
|
||||
# Store generated token
|
||||
generated_tokens[:, t] = next_token
|
||||
|
||||
# Check for EOS token - if all batches have generated EOS, stop
|
||||
if tokenizer.eos_token_id is not None:
|
||||
if (next_token == tokenizer.eos_token_id).all():
|
||||
break
|
||||
|
||||
# Embed the generated token and append to prefix
|
||||
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
|
||||
|
||||
def next_token_embed_func(next_token_unsqueezed):
|
||||
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
|
||||
next_emb_dim = next_emb.shape[-1]
|
||||
return next_emb * math.sqrt(next_emb_dim)
|
||||
|
||||
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
|
||||
|
||||
# Append to prefix embeddings
|
||||
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
|
||||
|
||||
# Update masks - new token is valid and uses causal attention
|
||||
prefix_pad_masks = torch.cat([
|
||||
prefix_pad_masks,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1)
|
||||
prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
|
||||
|
||||
return generated_tokens
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
prompt_tokens,
|
||||
prompt_masks,
|
||||
tokens,
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
tokenizer=None,
|
||||
max_target_tokens=50,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
|
||||
bsize = prompt_tokens.shape[0]
|
||||
device = prompt_tokens.device
|
||||
bsize = tokens.shape[0]
|
||||
device = tokens.device
|
||||
|
||||
if noise is None:
|
||||
# Sample noise with padded dimension as expected by action_in_proj
|
||||
@@ -955,33 +779,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
) # Use config max_action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
# Generate target tokens autoregressively during inference (if tokenizer provided)
|
||||
generated_target_tokens = None
|
||||
target_masks = None
|
||||
if tokenizer is not None:
|
||||
generated_target_tokens = self._generate_target_tokens(
|
||||
images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_target_tokens, device
|
||||
)
|
||||
|
||||
# Decode and print the generated target tokens
|
||||
for i in range(bsize):
|
||||
# Remove padding tokens (0) and special tokens
|
||||
valid_tokens = generated_target_tokens[i][generated_target_tokens[i] != 0]
|
||||
decoded_text = tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||
print(f"[Inference] Generated target {i}: {decoded_text}")
|
||||
|
||||
# Create mask for generated tokens (all valid where token != 0)
|
||||
target_masks = generated_target_tokens != 0
|
||||
|
||||
# Embed prefix with prompt and optionally generated target tokens
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
|
||||
images, img_masks, prompt_tokens, target_tokens=generated_target_tokens,
|
||||
prompt_masks=prompt_masks, target_masks=target_masks
|
||||
)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
@@ -993,16 +795,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)

x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
    expanded_time = time.expand(bsize)
for step in range(num_steps):
    time = 1.0 + step * dt
    time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)

    # Define a closure function to properly capture expanded_time
    # This avoids the lambda expression (E731) and loop variable binding (B023) issues
    def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
    def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
        return self.denoise_step(
            prefix_pad_masks=prefix_pad_masks,
            past_key_values=past_key_values,
@@ -1026,15 +825,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
    else:
        v_t = denoise_step_partial_call(x_t)

    # Euler step
    x_t += dt * v_t
    x_t = x_t + dt * v_t

    # Record x_t and v_t after Euler step
    if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
        self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)

    time += dt

return x_t
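
This hunk replaces the while-loop over a tensor-valued `time` with a fixed `for step in range(num_steps)` loop, but the integration itself is unchanged: starting from Gaussian noise at t = 1, each step moves along the predicted velocity with step size dt = -1/num_steps until t reaches 0. Switching from `x_t += dt * v_t` to `x_t = x_t + dt * v_t` avoids mutating a tensor that the RTC debug processor may already hold a reference to; that is the likely motivation, though the diff does not state it. A stripped-down sketch of the scheme, with `velocity_fn` standing in for the model's `denoise_step` call (a stand-in, not the repository API):

```python
import torch

def euler_flow_matching_sample(velocity_fn, noise: torch.Tensor, num_steps: int = 10) -> torch.Tensor:
    """Integrate dx/dt = v(x, t) from t = 1 (noise) down to t = 0 with fixed Euler steps."""
    dt = -1.0 / num_steps
    x_t = noise
    for step in range(num_steps):
        t = 1.0 + step * dt          # 1.0, 1 - 1/N, ..., down to 1/N
        v_t = velocity_fn(x_t, t)    # stand-in for the model's velocity prediction
        x_t = x_t + dt * v_t         # out-of-place Euler update, as in the hunk above
    return x_t

# toy velocity field v = x: integrating from t=1 to t=0 contracts the sample
actions = euler_flow_matching_sample(lambda x, t: x, torch.randn(2, 4))
print(actions.shape)
```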
|
||||
|
||||
def denoise_step(
|
||||
@@ -1058,7 +853,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=suffix_embs.dtype)
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
@@ -1103,14 +898,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
|
||||
# Load tokenizer for subtask decoding
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not load tokenizer for subtask decoding: {e}")
|
||||
self.tokenizer = None
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -1411,16 +1198,10 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
# Use prompt tokens (WITHOUT target) for inference - we'll generate the target
|
||||
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
|
||||
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
|
||||
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, prompt_tokens, prompt_masks,
|
||||
tokenizer=self.tokenizer,
|
||||
**kwargs
|
||||
)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1433,24 +1214,22 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
|
||||
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
|
||||
target_tokens, target_masks = batch[f"{OBS_LANGUAGE_TARGET_TOKENS}"], batch[f"{OBS_LANGUAGE_TARGET_ATTENTION_MASK}"]
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
# prompt_tokens = instruction tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
|
||||
# target_tokens = target tokens to predict (e.g., "pick up the cup")
|
||||
loss_dict = self.model.forward(images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Extract the total loss
|
||||
loss = loss_dict["loss"]
|
||||
|
||||
# Prepare detailed loss dictionary for logging
|
||||
detailed_loss_dict = {
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[:, :, :original_action_dim]
|
||||
|
||||
loss = losses.mean()
|
||||
|
||||
loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
||||
"target_loss": loss_dict["target_loss"].item(),
|
||||
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
|
||||
}
|
||||
|
||||
return loss, detailed_loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -47,15 +47,13 @@ from lerobot.utils.constants import (
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
|
||||
@dataclass
|
||||
class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Processor step to prepare the state and tokenize the language input.
|
||||
"""
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
prompt_key: str = "prompt"
|
||||
target_key: str = "target"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
@@ -66,8 +64,6 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("user_prompt")
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
@@ -80,33 +76,16 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Clean high level tasks first (if available)
|
||||
cleaned_high_level_tasks = []
|
||||
if high_level_tasks is not None:
|
||||
for high_level_task in high_level_tasks:
|
||||
cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " "))
|
||||
|
||||
# Process tasks to create prompts (input) and targets (what to predict)
|
||||
prompts = [] # Input prompts ending with "Subtask:"
|
||||
targets = [] # Target text to predict (the subtask)
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
|
||||
# Store the subtask text as target for prediction
|
||||
targets.append(cleaned_text)
|
||||
|
||||
if cleaned_high_level_tasks:
|
||||
cleaned_high_level_task = cleaned_high_level_tasks[i]
|
||||
# Prompt ends with "Subtask:" - model will predict the target
|
||||
prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
|
||||
else:
|
||||
raise ValueError("No high level tasks found")
|
||||
|
||||
prompts.append(prompt)
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.prompt_key] = prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.target_key] = targets
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
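
For reference, the state string interpolated into the prompt comes from binning the already-normalized state into 256 levels with `np.digitize`, as in the hunk above. A small self-contained sketch of that binning, assuming the state is already in [-1, 1]:

```python
import numpy as np

def discretize_state(state: np.ndarray, n_bins: int = 256) -> np.ndarray:
    """Map a state normalized to [-1, 1] onto integer bin ids in [0, n_bins - 1]."""
    bins = np.linspace(-1, 1, n_bins + 1)[:-1]  # left edges of the bins
    return np.digitize(state, bins=bins) - 1

state = np.array([-1.0, -0.5, 0.0, 0.5, 0.999])
print(" ".join(map(str, discretize_state(state))))  # "0 64 128 192 255"
```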
|
||||
|
||||
def transform_features(
|
||||
@@ -154,14 +133,14 @@ def make_pi05_pre_post_processors(
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
AddBatchDimensionProcessorStep(),
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateAndLanguageTokenizerProcessorStep
|
||||
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
|
||||
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
Pi05PrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
|
||||
@@ -79,6 +79,7 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10
|
||||
optimizer_fused: bool = False
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
@@ -136,6 +137,7 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
||||
@@ -783,18 +783,15 @@ class VLAFlowMatching(nn.Module):
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
num_steps = self.config.num_steps
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
@@ -818,15 +815,11 @@ class VLAFlowMatching(nn.Module):
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
|
||||
@@ -168,12 +168,10 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
    """
    pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
    task_key = {"task": batch["task"]} if "task" in batch else {}
    user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
    subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
    index_key = {"index": batch["index"]} if "index" in batch else {}
    task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}

    return {**pad_keys, **task_key, **index_key, **task_index_key, **user_prompt_key, **subtask_key}
    return {**pad_keys, **task_key, **index_key, **task_index_key}
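
After this change the helper forwards only padding flags, the task string, and the two index fields. A toy illustration of the trimmed behavior, using a made-up batch dict (the keys here are illustrative, not a schema):

```python
batch = {
    "action_is_pad": [False, False, True],
    "task": "pick up the cube",
    "index": 42,
    "task_index": 3,
    "user_prompt": "tidy the table",  # no longer forwarded after this change
}

pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}

print({**pad_keys, **task_key, **index_key, **task_index_key})
# {'action_is_pad': [False, False, True], 'task': 'pick up the cube', 'index': 42, 'task_index': 3}
```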
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -47,6 +47,7 @@ class RenameObservationsProcessorStep(ObservationProcessorStep):
|
||||
processed_obs[self.rename_map[key]] = value
|
||||
else:
|
||||
processed_obs[key] = value
|
||||
|
||||
return processed_obs
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
@@ -29,14 +29,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import (
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_PROMPT_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_TARGET_TOKENS,
|
||||
OBS_LANGUAGE_TARGET_ATTENTION_MASK,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
@@ -59,9 +52,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
||||
token IDs and attention mask to the `observation` dictionary.
|
||||
|
||||
Optionally, this step can also tokenize a prompt (input for generation) and/or
|
||||
a target (text to predict) if present in the complementary data, creating separate tokenized observations.
|
||||
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
@@ -69,8 +59,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
max_length: The maximum length to pad or truncate sequences to.
|
||||
task_key: The key in `complementary_data` where the task string is stored.
|
||||
prompt_key: The key in `complementary_data` where the prompt (input for generation) is stored.
|
||||
target_key: The key in `complementary_data` where the target (text to predict) is stored.
|
||||
padding_side: The side to pad on ('left' or 'right').
|
||||
padding: The padding strategy ('max_length', 'longest', etc.).
|
||||
truncation: Whether to truncate sequences longer than `max_length`.
|
||||
@@ -81,8 +69,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
prompt_key: str = "prompt"
|
||||
target_key: str = "target"
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
@@ -135,7 +121,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
raise ValueError("Complementary data is None so no task can be extracted from it")
|
||||
|
||||
task = complementary_data[self.task_key]
|
||||
|
||||
if task is None:
|
||||
raise ValueError("Task extracted from Complementary data is None")
|
||||
|
||||
@@ -147,60 +132,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_prompt(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the prompt (input for generation) from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of prompt strings, or None if the prompt key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
prompt = complementary_data.get(self.prompt_key)
|
||||
|
||||
if prompt is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(prompt, str):
|
||||
return [prompt]
|
||||
elif isinstance(prompt, list) and all(isinstance(t, str) for t in prompt):
|
||||
return prompt
|
||||
|
||||
return None
|
||||
|
||||
def get_target(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the target (text to predict) from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of target strings, or None if the target key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
target = complementary_data.get(self.target_key)
|
||||
|
||||
if target is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(target, str):
|
||||
return [target]
|
||||
elif isinstance(target, list) and all(isinstance(t, str) for t in target):
|
||||
return target
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
@@ -238,38 +169,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Also tokenize prompt (input for generation) if available
|
||||
prompt = self.get_prompt(self.transition)
|
||||
if prompt is not None:
|
||||
tokenized_prompt_input = self._tokenize_text(prompt)
|
||||
|
||||
# Move to the same device
|
||||
if target_device is not None:
|
||||
tokenized_prompt_input = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_prompt_input.items()
|
||||
}
|
||||
|
||||
# Add prompt tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_PROMPT_TOKENS] = tokenized_prompt_input["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = tokenized_prompt_input["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Also tokenize target (text to predict) if available
|
||||
target = self.get_target(self.transition)
|
||||
if target is not None:
|
||||
tokenized_target = self._tokenize_text(target)
|
||||
|
||||
# Move to the same device
|
||||
if target_device is not None:
|
||||
tokenized_target = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_target.items()
|
||||
}
|
||||
|
||||
# Add target tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_TARGET_TOKENS] = tokenized_target["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_TARGET_ATTENTION_MASK] = tokenized_target["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
@@ -330,8 +229,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
config = {
|
||||
"max_length": self.max_length,
|
||||
"task_key": self.task_key,
|
||||
"prompt_key": self.prompt_key,
|
||||
"target_key": self.target_key,
|
||||
"padding_side": self.padding_side,
|
||||
"padding": self.padding,
|
||||
"truncation": self.truncation,
|
||||
@@ -370,26 +267,4 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for prompt tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for target tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_TARGET_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_TARGET_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
Edit LeRobot datasets using various transformation tools.
|
||||
|
||||
This script allows you to delete episodes, split datasets, merge datasets,
|
||||
and remove features. When new_repo_id is specified, creates a new dataset.
|
||||
remove features, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Usage Examples:
|
||||
|
||||
@@ -65,6 +66,25 @@ Remove camera feature:
    --operation.type remove_feature \
    --operation.feature_names "['observation.images.top']"

Convert image dataset to video format (saves locally):
    python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht_image \
        --operation.type convert_to_video \
        --operation.output_dir /path/to/output/pusht_video

Convert image dataset and save with new repo_id:
    python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht_image \
        --new_repo_id lerobot/pusht_video \
        --operation.type convert_to_video

Convert and push to hub:
    python -m lerobot.scripts.lerobot_edit_dataset \
        --repo_id lerobot/pusht_image \
        --new_repo_id lerobot/pusht_video \
        --operation.type convert_to_video \
        --push_to_hub true

Using JSON config file:
    python -m lerobot.scripts.lerobot_edit_dataset \
        --config_path path/to/edit_config.json
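
No example config file is shown in this hunk, so here is a hedged sketch of what such a file might contain for the new operation, using only fields defined by `ConvertToVideoConfig` and `EditDatasetConfig` further down; the exact schema accepted by `--config_path` is an assumption here, not documented behavior:

```python
import json

# Hypothetical edit_config.json for the convert_to_video operation.
config = {
    "repo_id": "lerobot/pusht_image",
    "new_repo_id": "lerobot/pusht_video",
    "push_to_hub": False,
    "operation": {"type": "convert_to_video", "vcodec": "libsvtav1", "crf": 30, "num_workers": 4},
}
with open("edit_config.json", "w") as f:
    json.dump(config, f, indent=2)
```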
|
||||
@@ -72,9 +92,13 @@ Using JSON config file:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
delete_episodes,
|
||||
@@ -82,8 +106,10 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import write_stats, write_tasks
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -111,10 +137,23 @@ class RemoveFeatureConfig:
|
||||
feature_names: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertToVideoConfig:
|
||||
type: str = "convert_to_video"
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int = 2
|
||||
crf: int = 30
|
||||
fast_decode: int = 0
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
@@ -258,6 +297,415 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def encode_episode_videos(
|
||||
dataset: LeRobotDataset,
|
||||
new_meta: LeRobotDatasetMetadata,
|
||||
episode_index: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
temp_dir: Path,
|
||||
num_image_workers: int = 4,
|
||||
) -> dict[str, dict]:
|
||||
"""Encode videos for a single episode and return video metadata.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images
|
||||
new_meta: Metadata object for the new video dataset
|
||||
episode_index: Episode index to process
|
||||
vcodec: Video codec
|
||||
pix_fmt: Pixel format
|
||||
g: Group of pictures size
|
||||
crf: Constant rate factor
|
||||
fast_decode: Fast decode tuning
|
||||
temp_dir: Temporary directory for images
|
||||
num_image_workers: Number of workers for saving images
|
||||
|
||||
Returns:
|
||||
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
video_metadata = {}
|
||||
fps = int(dataset.fps) # Convert to int for PyAV compatibility
|
||||
episode_length = dataset.meta.episodes["length"][episode_index]
|
||||
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
|
||||
|
||||
for img_key in img_keys:
|
||||
# Save images temporarily
|
||||
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
|
||||
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
|
||||
|
||||
# Determine chunk and file indices
|
||||
# For simplicity, we'll put each episode in its own file
|
||||
chunk_idx = episode_index // new_meta.chunks_size
|
||||
file_idx = episode_index % new_meta.chunks_size
|
||||
|
||||
# Create video path in the new dataset structure
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encode video
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Store video metadata
|
||||
video_metadata[img_key] = {
|
||||
f"videos/{img_key}/chunk_index": chunk_idx,
|
||||
f"videos/{img_key}/file_index": file_idx,
|
||||
f"videos/{img_key}/from_timestamp": 0.0,
|
||||
f"videos/{img_key}/to_timestamp": episode_duration,
|
||||
}
|
||||
|
||||
return video_metadata
|
||||
|
||||
|
||||
def convert_dataset_to_videos(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-based dataset to video-based dataset.
|
||||
|
||||
Creates a new LeRobotDataset with videos instead of images, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process each episode
|
||||
all_episode_metadata = []
|
||||
|
||||
try:
|
||||
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
|
||||
# Get episode metadata from source
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
|
||||
# Encode videos for this episode
|
||||
video_metadata = encode_episode_videos(
|
||||
dataset=dataset,
|
||||
new_meta=new_meta,
|
||||
episode_index=ep_idx,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
temp_dir=temp_dir,
|
||||
num_image_workers=num_workers,
|
||||
)
|
||||
|
||||
# Build episode metadata
|
||||
episode_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": src_episode["length"],
|
||||
"dataset_from_index": ep_idx * src_episode["length"],
|
||||
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
|
||||
}
|
||||
|
||||
# Add video metadata
|
||||
for img_key in img_keys:
|
||||
episode_meta.update(video_metadata[img_key])
|
||||
|
||||
# Add data chunk/file info (using same structure as source)
|
||||
if "data/chunk_index" in src_episode:
|
||||
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
episode_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
|
||||
all_episode_metadata.append(episode_meta)
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(all_episode_metadata)
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
from lerobot.datasets.utils import write_info
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
|
||||
# Determine output directory and repo_id
|
||||
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
|
||||
output_dir_config = getattr(cfg.operation, "output_dir", None)
|
||||
|
||||
if cfg.new_repo_id:
|
||||
# Use new_repo_id for both local storage and hub push
|
||||
output_repo_id = cfg.new_repo_id
|
||||
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
|
||||
elif output_dir_config:
|
||||
# Use custom output directory for local-only storage
|
||||
output_dir = Path(output_dir_config)
|
||||
# Extract repo name from output_dir for the dataset
|
||||
output_repo_id = output_dir.name
|
||||
logging.info(f"Saving to local directory: {output_dir}")
|
||||
else:
|
||||
# Auto-generate name: append "_video" to original repo_id
|
||||
output_repo_id = f"{cfg.repo_id}_video"
|
||||
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
|
||||
logging.info(f"Saving to auto-generated location: {output_dir}")
|
||||
|
||||
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||
|
||||
new_dataset = convert_dataset_to_videos(
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
|
||||
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
|
||||
g=getattr(cfg.operation, "g", 2),
|
||||
crf=getattr(cfg.operation, "crf", 30),
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
)
|
||||
|
||||
logging.info("Video dataset created successfully!")
|
||||
logging.info(f"Location: {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}")
|
||||
logging.info(f"Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}...")
|
||||
new_dataset.push_to_hub()
|
||||
logging.info("✓ Successfully pushed to hub!")
|
||||
else:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -270,10 +718,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_merge(cfg)
|
||||
elif operation_type == "remove_feature":
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "convert_to_video":
|
||||
handle_convert_to_video(cfg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -26,12 +26,7 @@ OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_PROMPT = OBS_STR + ".prompt"
OBS_LANGUAGE_PROMPT_TOKENS = OBS_LANGUAGE_PROMPT + ".tokens"
OBS_LANGUAGE_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_PROMPT + ".attention_mask"
OBS_LANGUAGE_TARGET = OBS_STR + ".target"
OBS_LANGUAGE_TARGET_TOKENS = OBS_LANGUAGE_TARGET + ".tokens"
OBS_LANGUAGE_TARGET_ATTENTION_MASK = OBS_LANGUAGE_TARGET + ".attention_mask"

ACTION = "action"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1047,3 +1048,107 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
|
||||
assert new_file_indices == original_file_indices, "File indices should be preserved"
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
|
||||
|
||||
def test_convert_dataset_to_videos(tmp_path):
    """Test converting lerobot/pusht_image dataset to video format."""
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    # Load the actual lerobot/pusht_image dataset (only first 2 episodes for speed)
    source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1])

    output_dir = tmp_path / "pusht_video"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        # Verify source dataset has images, not videos
        assert len(source_dataset.meta.video_keys) == 0
        assert "observation.image" in source_dataset.meta.features

        # Convert to video dataset (only first 2 episodes for speed)
        video_dataset = convert_dataset_to_videos(
            dataset=source_dataset,
            output_dir=output_dir,
            repo_id="lerobot/pusht_video",
            vcodec="libsvtav1",
            pix_fmt="yuv420p",
            g=2,
            crf=30,
            episode_indices=[0, 1],
            num_workers=2,
        )

        # Verify new dataset has videos
        assert len(video_dataset.meta.video_keys) > 0
        assert "observation.image" in video_dataset.meta.video_keys

        # Verify correct number of episodes and frames (2 episodes)
        assert video_dataset.meta.total_episodes == 2
        # Compare against the actual number of frames in the loaded episodes, not metadata total
        assert len(video_dataset) == len(source_dataset)

        # Verify video files exist
        for ep_idx in range(video_dataset.meta.total_episodes):
            for video_key in video_dataset.meta.video_keys:
                video_path = video_dataset.root / video_dataset.meta.get_video_file_path(ep_idx, video_key)
                assert video_path.exists(), f"Video file should exist: {video_path}"

        # Verify we can load the dataset and access it
        assert len(video_dataset) == video_dataset.meta.total_frames

        # Test that we can actually get an item from the video dataset
        item = video_dataset[0]
        assert "observation.image" in item
        assert "action" in item

    # Cleanup
    import shutil

    if output_dir.exists():
        shutil.rmtree(output_dir)


def test_convert_dataset_to_videos_subset_episodes(tmp_path):
    """Test converting only specific episodes from lerobot/pusht_image to video format."""
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    # Load the actual lerobot/pusht_image dataset (only first 3 episodes)
    source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1, 2])

    output_dir = tmp_path / "pusht_video_subset"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        # Convert only episode 0 to video (subset of loaded episodes)
        episode_indices = [0]

        video_dataset = convert_dataset_to_videos(
            dataset=source_dataset,
            output_dir=output_dir,
            repo_id="lerobot/pusht_video_subset",
            episode_indices=episode_indices,
            num_workers=2,
        )

        # Verify correct number of episodes
        assert video_dataset.meta.total_episodes == len(episode_indices)

        # Verify video files exist for selected episodes
        assert len(video_dataset.meta.video_keys) > 0
        assert "observation.image" in video_dataset.meta.video_keys

    # Cleanup
    import shutil

    if output_dir.exists():
        shutil.rmtree(output_dir)

@@ -266,7 +266,7 @@ def create_original_observation_with_openpi_preprocessing(batch):
    elif len(tasks) == 1:
        tasks = tasks * batch_size

    # Use pi05 state and input tokenizer logic (same as Pi05PrepareStateAndLanguageTokenizerProcessorStep)
    # Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
    state = batch["observation.state"]
    state = deepcopy(state)
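For context, the task handling in this fragment broadcasts a single task string across the batch before tokenization. A standalone sketch of that step follows; the batch contents are made up for illustration:

```python
from copy import deepcopy

batch = {
    "task": ["push the T-shaped block"],          # one task for the whole batch
    "observation.state": [[0.0] * 6, [0.1] * 6],  # two samples
}
batch_size = len(batch["observation.state"])

tasks = batch["task"]
if len(tasks) == 1:
    # Repeat the single task so every sample has a matching prompt.
    tasks = tasks * batch_size

state = deepcopy(batch["observation.state"])  # copy before any in-place preprocessing
assert tasks == ["push the T-shaped block"] * batch_size
```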