Files
lerobot/examples/dataset/annotate_pgen.py
T
2025-12-10 16:22:54 +00:00

1536 lines
57 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Synthetic Data Generation for Hi-Robot Style Hierarchical Policy Training.
This script generates synthetic user prompts (_t) and robot utterances (u_t) for
hierarchical policy training using Qwen VLM as the generator model (pgen).
The pipeline:
1. Loads a LeRobot dataset with skill annotations (from annotate.py)
2. For each frame, generates synthetic dialogue based on:
- Visual context (images at time t OR video clips in video mode)
- Current skill being performed
- History of previous skills
- High-level task description
3. Saves results as high-level tasks and updates dataset with task_index_high_level
Modes:
- Image Mode (default): Samples frames at intervals and sends images to the model
- Video Mode (--video-mode): Passes entire skill video clips to the model
Usage:
```bash
# Image mode (default)
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--output-dir /path/to/output
# Video mode with batch processing
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--video-mode \
--video-key observation.images.base \
--video-batch-size 4
```
"""
import argparse
import json
import re
import subprocess
import tempfile
import textwrap
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from tqdm import tqdm
from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# =============================================================================
# Video Extraction Utilities
# =============================================================================
class VideoExtractor:
"""Utilities for extracting and processing video segments from LeRobot datasets."""
def __init__(self, console: Console | None = None):
self.console = console or Console()
def extract_episode_video(
self,
video_path: Path,
start_timestamp: float,
end_timestamp: float,
target_fps: int = 1,
) -> Path:
"""
Extract a specific episode segment from a concatenated video file.
Args:
video_path: Path to the source video file
start_timestamp: Start time in seconds
end_timestamp: End time in seconds
target_fps: Target frames per second for output
Returns:
Path to the extracted temporary video file
"""
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
tmp_path = Path(tmp_file.name)
tmp_file.close()
duration = end_timestamp - start_timestamp
cmd = [
"ffmpeg",
"-i",
str(video_path),
"-ss",
str(start_timestamp),
"-t",
str(duration),
"-r",
str(target_fps),
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-crf",
"23",
"-an",
"-y",
str(tmp_path),
]
try:
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"FFmpeg failed: {e}") from e
except FileNotFoundError:
raise RuntimeError("FFmpeg not found. Please install ffmpeg.")
if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
if tmp_path.exists():
tmp_path.unlink()
raise RuntimeError("Video extraction produced invalid file")
return tmp_path
def get_video_duration(self, video_path: Path) -> float:
"""Get duration of a video file in seconds."""
cap = cv2.VideoCapture(str(video_path))
fps = cap.get(cv2.CAP_PROP_FPS) or 30
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
cap.release()
return frame_count / fps
# =============================================================================
# Prompt Template for pgen
# =============================================================================
PGEN_PROMPT_TEMPLATE_IMAGE = textwrap.dedent("""\
# Role
You are a robot-assistant dialogue generator for hierarchical robot policies.
# Task
You will receive:
- A list of images showing the current robot scene at time t
- The high-level task: {task_description}
- Previous skill steps completed: {skill_history}
- The next skill to be performed by the robot: {skill_current}
# Your Goal
Generate two things that create a natural human-robot interaction:
1. **user_prompt**: A natural-sounding user request that logically leads to the robot
performing the skill "{skill_current}" given the task context and history.
2. **robot_utterance**: A natural robot reply acknowledging or clarifying the request.
# Guidelines
- The user prompt should be grounded in the visual scene and task context
- Vary interaction types: direct commands, implicit requests, corrections, constraints
- Examples of user prompt styles:
* Direct: "Can you pick up the red brick?"
* Implicit: "I need something red for the tower"
* Negative: "Don't pick up the blue one"
* Constraint: "Make sure to handle it gently"
* Correction: "Actually, move to the other box instead"
- Robot responses should be appropriate: confirmations, clarifications, or error handling
- Use the skill history to ensure continuity (don't repeat past actions)
- Consider world knowledge (dietary preferences, object properties, etc.)
# Scenario Types (choose one that fits):
- **specific_object**: User specifies exact object/action
- **negative_task**: User says what NOT to do
- **situated_correction**: User adjusts based on current state
- **implicit_request**: User implies need without direct command
- **constraint_based**: User adds specific constraints
# Response Types (choose one that fits):
- **confirmation**: Simple "OK, I'll do X"
- **clarification**: "Just to confirm, you want me to..."
- **acknowledgment**: "Got it, [doing action]"
- **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"
# Output Format
Respond ONLY with valid JSON:
{{
"scenario_type": "one of the types above",
"response_type": "one of the types above",
"user_prompt": "natural user request here",
"robot_utterance": "natural robot response here"
}}
The responses must be grounded in the visual scene, the task, and the skill history.
Make it sound like a real human-robot interaction.
""")
PGEN_PROMPT_TEMPLATE_VIDEO = textwrap.dedent("""\
# Role
You are a robot-assistant dialogue generator for hierarchical robot policies.
# Task
You are watching a full robot demonstration video for the task: {task_description}
For each timestamp below, generate natural human-robot dialogue that would have led to the observed behavior.
At each timestamp, you'll see:
- What skills have been completed so far (cumulative history)
- What skill is currently being executed
{timestamp_context}
# Your Goal
For EACH timestamp, generate:
1. **user_prompt**: A natural user request that would lead to the robot performing the current skill
2. **robot_utterance**: A natural robot response acknowledging the request
# Guidelines
- Watch the video from start to each timestamp to understand the context
- Ground prompts in what's visible in the video at that time
- Vary interaction types: direct commands, implicit requests, corrections, constraints
- Examples of user prompt styles:
* Direct: "Can you pick up the red brick?"
* Implicit: "I need something red for the tower"
* Negative: "Don't pick up the blue one"
* Constraint: "Make sure to handle it gently"
* Correction: "Actually, move to the other box instead"
- Robot responses should be appropriate: confirmations, clarifications, or error handling
- Ensure continuity across timestamps (don't contradict earlier dialogue)
- Consider world knowledge (dietary preferences, object properties, etc.)
# Scenario Types:
- **specific_object**: User specifies exact object/action
- **negative_task**: User says what NOT to do
- **situated_correction**: User adjusts based on current state
- **implicit_request**: User implies need without direct command
- **constraint_based**: User adds specific constraints
# Response Types:
- **confirmation**: Simple "OK, I'll do X"
- **clarification**: "Just to confirm, you want me to..."
- **acknowledgment**: "Got it, [doing action]"
- **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"
# Output Format
Respond ONLY with valid JSON array:
[
{{
"timestamp": timestamp_value,
"scenario_type": "one of the types above",
"response_type": "one of the types above",
"user_prompt": "natural user request here",
"robot_utterance": "natural robot response here"
}},
... (one entry per timestamp)
]
Make it sound like a real human-robot interaction grounded in the video.
""")
def construct_prompt_image(
task_description: str,
skill_history: list[str],
skill_current: str,
) -> str:
"""
Construct the text prompt for pgen in image mode.
Args:
task_description: High-level task description
skill_history: List of previously completed skills
skill_current: Current skill to be performed
Returns:
Formatted prompt string
"""
# Format skill history nicely
if skill_history:
history_str = ", ".join(f'"{s}"' for s in skill_history[-5:]) # Last 5 for context
if len(skill_history) > 5:
history_str = f"... {history_str}"
else:
history_str = "None (starting the task)"
return PGEN_PROMPT_TEMPLATE_IMAGE.format(
task_description=task_description,
skill_history=history_str,
skill_current=skill_current,
)
def construct_prompt_video(
task_description: str,
timestamps_with_skills: list[dict],
) -> str:
"""
Construct the text prompt for pgen in video mode.
Args:
task_description: High-level task description
timestamps_with_skills: List of dicts with keys:
- timestamp: float
- skills_so_far: list[str]
- current_skill: str
Returns:
Formatted prompt string
"""
# Build timestamp context
timestamp_lines = []
for item in timestamps_with_skills:
ts = item["timestamp"]
skills_so_far = item["skills_so_far"]
current_skill = item["current_skill"]
if skills_so_far:
skills_str = ", ".join(f'"{s}"' for s in skills_so_far)
else:
skills_str = "None (starting)"
timestamp_lines.append(
f"- **Timestamp {ts:.2f}s**: Skills completed: [{skills_str}] | Current skill: \"{current_skill}\""
)
timestamp_context = "\n".join(timestamp_lines)
return PGEN_PROMPT_TEMPLATE_VIDEO.format(
task_description=task_description,
timestamp_context=timestamp_context,
)
# =============================================================================
# Qwen VLM Interface
# =============================================================================
class QwenPgen:
"""Qwen VLM wrapper for synthetic dialogue generation."""
def __init__(
self,
model_name: str,
device: str = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
temperature: float = 0.7,
):
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
self.console = Console()
self.device = device
self.model_name = model_name
self.temperature = temperature
self.process_vision_info = process_vision_info
self.console.print(f"[cyan]Loading Qwen model: {model_name}...[/cyan]")
# Load model based on name
if "qwen3" in model_name.lower():
from transformers import Qwen3VLMoeForConditionalGeneration
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
else:
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
def call_qwen(
self,
images: list[Image.Image | str] | None = None,
prompt: str = "",
video: str | Path | None = None,
) -> dict[str, str]:
"""
Call Qwen VLM to generate synthetic dialogue for a single request.
Args:
images: List of PIL Images or image paths (for image mode)
prompt: Text prompt for generation
video: Path to video file (for video mode)
Returns:
Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
"""
# Use batch method with single item
results = self.call_qwen_batch(
batch_images=[images] if images else [None],
batch_prompts=[prompt],
batch_videos=[video] if video else [None],
)
return results[0]
def call_qwen_batch(
self,
batch_images: list[list[Image.Image | str] | None],
batch_prompts: list[str],
batch_videos: list[str | Path | None] | None = None,
) -> list[dict[str, str]]:
"""
Call Qwen VLM to generate synthetic dialogue for a batch of requests.
Args:
batch_images: List of image lists, one per request (None for video mode)
batch_prompts: List of text prompts, one per request
batch_videos: List of video paths, one per request (None for image mode)
Returns:
List of dictionaries, each with keys: scenario_type, response_type, user_prompt, robot_utterance
"""
if batch_videos is None:
batch_videos = [None] * len(batch_images)
if len(batch_images) != len(batch_prompts) or len(batch_images) != len(batch_videos):
raise ValueError(
f"Batch size mismatch: {len(batch_images)} image lists vs "
f"{len(batch_prompts)} prompts vs {len(batch_videos)} videos"
)
batch_size = len(batch_images)
if batch_size == 0:
return []
# Build messages for each item in batch
all_messages = []
for images, prompt, video in zip(batch_images, batch_prompts, batch_videos):
content = []
# Add video or images
if video is not None:
# Video mode
content.append({"type": "video", "video": str(video), "fps": 1.0})
elif images is not None:
# Image mode
for img in images:
if isinstance(img, str):
content.append({"type": "image", "image": img})
else:
# PIL Image
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": prompt})
messages = [
{
"role": "user",
"content": content,
}
]
all_messages.append(messages)
# Process all inputs
texts = []
all_image_inputs = []
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
texts.append(text)
image_inputs, video_inputs = self.process_vision_info(messages)
all_image_inputs.append(image_inputs)
all_video_inputs.append(video_inputs)
# Flatten image and video inputs for batch processing
# The processor expects a flat list of images across all batch items
flat_images = []
for img_list in all_image_inputs:
if img_list is not None:
if isinstance(img_list, list):
flat_images.extend(img_list)
else:
flat_images.append(img_list)
flat_videos = []
for vid_list in all_video_inputs:
if vid_list is not None:
if isinstance(vid_list, list):
flat_videos.extend(vid_list)
else:
flat_videos.append(vid_list)
# Process batch
inputs = self.processor(
text=texts,
images=flat_images if flat_images else None,
videos=flat_videos if flat_videos else None,
padding=True,
return_tensors="pt",
).to(self.device)
# Generate
with torch.no_grad():
generated_ids = self.model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=self.temperature,
)
# Decode responses
responses = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
skip_special_tokens=True,
)
# Parse all responses
results = []
for response in responses:
try:
parsed = self._parse_response(response.strip())
results.append(parsed)
except Exception as e:
self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
# Return empty/default result
results.append({
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "",
"robot_utterance": "",
})
return results
def _parse_response(self, response: str) -> dict[str, str]:
"""Parse JSON response from model."""
# Extract JSON from response
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
return {
"scenario_type": data.get("scenario_type", "specific_object"),
"response_type": data.get("response_type", "confirmation"),
"user_prompt": data.get("user_prompt", ""),
"robot_utterance": data.get("robot_utterance", ""),
}
except json.JSONDecodeError:
# Try to find JSON object in response
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
data = json.loads(match.group())
return {
"scenario_type": data.get("scenario_type", "specific_object"),
"response_type": data.get("response_type", "confirmation"),
"user_prompt": data.get("user_prompt", ""),
"robot_utterance": data.get("robot_utterance", ""),
}
raise ValueError(f"Could not parse response: {response[:200]}...")
# =============================================================================
# Annotation Pipeline
# =============================================================================
def load_skills_metadata(dataset_root: Path) -> dict | None:
"""Load skills.json metadata from annotated dataset."""
skills_path = dataset_root / "meta" / "skills.json"
if skills_path.exists():
with open(skills_path) as f:
return json.load(f)
return None
def get_skill_at_timestamp(skills: list[dict], timestamp: float) -> str | None:
"""Find which skill covers a given timestamp."""
for skill in skills:
if skill["start"] <= timestamp < skill["end"]:
return skill["name"]
# Handle last frame
if timestamp >= skill["end"] and skill == skills[-1]:
return skill["name"]
return skills[-1]["name"] if skills else None
def annotate_sample_image(
pgen: QwenPgen,
images: list[Image.Image | str],
task_description: str,
skill_history: list[str],
skill_current: str,
) -> dict[str, str]:
"""
Generate synthetic dialogue for a single sample using images.
Args:
pgen: Qwen model wrapper
images: List of images at current timestep
task_description: High-level task description
skill_history: Previous skills completed
skill_current: Current skill being performed
Returns:
Dictionary with generated dialogue
"""
prompt = construct_prompt_image(task_description, skill_history, skill_current)
result = pgen.call_qwen(images=images, prompt=prompt, video=None)
return result
def annotate_episode_video(
pgen: QwenPgen,
video: str | Path,
task_description: str,
timestamps_with_skills: list[dict],
) -> list[dict[str, Any]]:
"""
Generate synthetic dialogue for an entire episode using video.
Args:
pgen: Qwen model wrapper
video: Path to episode video file
task_description: High-level task description
timestamps_with_skills: List of dicts with timestamp, skills_so_far, current_skill
Returns:
List of dictionaries with generated dialogue, one per timestamp
"""
# Use batch method with single episode
results = annotate_episodes_video_batch(
pgen=pgen,
batch_videos=[video],
batch_task_descriptions=[task_description],
batch_timestamps_with_skills=[timestamps_with_skills],
)
return results[0]
def annotate_episodes_video_batch(
pgen: QwenPgen,
batch_videos: list[str | Path],
batch_task_descriptions: list[str],
batch_timestamps_with_skills: list[list[dict]],
) -> list[list[dict[str, Any]]]:
"""
Generate synthetic dialogue for multiple episodes using videos in batch.
Args:
pgen: Qwen model wrapper
batch_videos: List of paths to episode video files
batch_task_descriptions: List of high-level task descriptions
batch_timestamps_with_skills: List of timestamp lists, one per episode
Returns:
List of result lists, one per episode (each containing dicts with generated dialogue)
"""
batch_size = len(batch_videos)
if batch_size == 0:
return []
# Build messages for each episode
all_messages = []
for video, task_desc, timestamps_with_skills in zip(
batch_videos, batch_task_descriptions, batch_timestamps_with_skills
):
prompt = construct_prompt_video(task_desc, timestamps_with_skills)
content = [
{"type": "video", "video": str(video), "fps": 1.0},
{"type": "text", "text": prompt},
]
messages = [{"role": "user", "content": content}]
all_messages.append(messages)
# Process all episodes through Qwen in batch
all_texts = []
all_image_inputs = []
all_video_inputs = []
for messages in all_messages:
text = pgen.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = pgen.process_vision_info(messages)
all_texts.append(text)
all_image_inputs.extend(image_inputs or [])
all_video_inputs.extend(video_inputs or [])
inputs = pgen.processor(
text=all_texts,
images=all_image_inputs if all_image_inputs else None,
videos=all_video_inputs if all_video_inputs else None,
padding=True,
return_tensors="pt",
).to(pgen.device)
with torch.no_grad():
generated_ids = pgen.model.generate(
**inputs,
max_new_tokens=2048, # Larger for multiple timestamps per episode
do_sample=True,
temperature=pgen.temperature,
)
responses = pgen.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
skip_special_tokens=True,
)
# Parse each response
all_results = []
for response, timestamps_with_skills in zip(responses, batch_timestamps_with_skills):
results = _parse_video_response(response.strip(), timestamps_with_skills)
all_results.append(results)
return all_results
def _parse_video_response(response: str, timestamps_with_skills: list[dict]) -> list[dict[str, Any]]:
"""Parse JSON array response from video mode."""
# Extract JSON from response
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
if not isinstance(data, list):
# If it's a dict with a list inside
if "annotations" in data:
data = data["annotations"]
elif "results" in data:
data = data["results"]
else:
raise ValueError("Expected JSON array or dict with 'annotations'/'results' key")
results = []
for item in data:
results.append({
"timestamp": item.get("timestamp", 0.0),
"scenario_type": item.get("scenario_type", "specific_object"),
"response_type": item.get("response_type", "confirmation"),
"user_prompt": item.get("user_prompt", ""),
"robot_utterance": item.get("robot_utterance", ""),
})
return results
except json.JSONDecodeError:
# Try to find JSON array in response
match = re.search(r"\[.*\]", response, re.DOTALL)
if match:
data = json.loads(match.group())
results = []
for item in data:
results.append({
"timestamp": item.get("timestamp", 0.0),
"scenario_type": item.get("scenario_type", "specific_object"),
"response_type": item.get("response_type", "confirmation"),
"user_prompt": item.get("user_prompt", ""),
"robot_utterance": item.get("robot_utterance", ""),
})
return results
breakpoint()
# Fallback: return empty results for each timestamp
print(f"Warning: Could not parse video response: {response[:200]}...")
return [
{
"timestamp": ts["timestamp"],
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "",
"robot_utterance": "",
}
for ts in timestamps_with_skills
]
def _generate_synthetic_data_video_mode(
dataset: LeRobotDataset,
pgen: QwenPgen,
skills_metadata: dict,
video_key: str,
video_extractor: VideoExtractor,
console: Console,
sample_interval_seconds: float = 1.0,
batch_size: int = 1,
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
"""
Generate synthetic dialogue data using video mode with batched VLM calls.
The VLM sees full episode videos and generates dialogue for multiple
timestamps per episode, with cumulative skill history at each timestamp.
Args:
dataset: LeRobot dataset with skill annotations
pgen: Qwen model wrapper
skills_metadata: Loaded skills.json metadata
video_key: Video observation key (e.g., 'observation.images.base')
video_extractor: VideoExtractor instance
console: Rich console for logging
sample_interval_seconds: Sample timestamps at this interval
batch_size: Number of episodes to process in each VLM batch call
Returns:
Tuple of (tasks_df, task_indices_array, debug_outputs)
"""
coarse_description = skills_metadata.get("coarse_description", "Complete the task")
episodes = skills_metadata.get("episodes", {})
# Track unique high-level tasks
high_level_tasks = {}
task_index_counter = 0
# Array to store task index for each frame
full_dataset_length = len(dataset)
task_indices = np.zeros(full_dataset_length, dtype=np.int64)
debug_outputs = []
timestamps_processed = 0
console.print(f"[cyan]Processing {len(episodes)} episodes in VIDEO MODE with batch_size={batch_size}...[/cyan]")
console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s[/cyan]")
# Convert episodes dict to list for batching
episode_list = list(episodes.items())
# Process episodes in batches
for batch_start in tqdm(range(0, len(episode_list), batch_size), desc="Processing episode batches"):
batch_end = min(batch_start + batch_size, len(episode_list))
batch_episodes = episode_list[batch_start:batch_end]
# Collect data for this batch
batch_data = []
extracted_videos = []
for episode_key, episode_data in batch_episodes:
episode_idx = int(episode_key)
skills = episode_data.get("skills", [])
description = episode_data.get("description", coarse_description)
if not skills:
console.print(f"[yellow]Warning: Episode {episode_idx} has no skills[/yellow]")
continue
# Get video path and extract full episode
extracted_path = None
try:
video_path = dataset.root / dataset.meta.get_video_file_path(episode_idx, video_key)
if not video_path.exists():
console.print(f"[yellow]Warning: Video not found for episode {episode_idx}[/yellow]")
continue
# Get episode timestamps
ep = dataset.meta.episodes[episode_idx]
episode_start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
episode_end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
duration = episode_end_ts - episode_start_ts
# Extract FULL episode video
extracted_path = video_extractor.extract_episode_video(
video_path, episode_start_ts, episode_end_ts, target_fps=1
)
extracted_videos.append(extracted_path)
except Exception as e:
console.print(f"[yellow]Warning: Failed to extract video for episode {episode_idx}: {e}[/yellow]")
continue
# Build list of timestamps to sample
timestamps_with_skills = []
current_time = 0.0
while current_time <= duration:
# Find which skill is active at this timestamp
current_skill = None
skills_so_far = []
for skill in skills:
if skill["end"] <= current_time:
skills_so_far.append(skill["name"])
elif skill["start"] <= current_time < skill["end"]:
current_skill = skill["name"]
break
elif current_time >= skill["end"] and skill == skills[-1]:
current_skill = skill["name"]
break
if current_skill:
timestamps_with_skills.append({
"timestamp": current_time,
"skills_so_far": skills_so_far.copy(),
"current_skill": current_skill,
})
current_time += sample_interval_seconds
if not timestamps_with_skills:
console.print(f"[yellow]Warning: No valid timestamps for episode {episode_idx}[/yellow]")
continue
# Store batch item
batch_data.append({
"episode_idx": episode_idx,
"episode_metadata": ep,
"video_path": extracted_path,
"task_description": description,
"timestamps_with_skills": timestamps_with_skills,
"skills": skills,
})
if not batch_data:
continue
# BATCHED VLM CALL for all episodes in batch
try:
batch_results = annotate_episodes_video_batch(
pgen=pgen,
batch_videos=[item["video_path"] for item in batch_data],
batch_task_descriptions=[item["task_description"] for item in batch_data],
batch_timestamps_with_skills=[item["timestamps_with_skills"] for item in batch_data],
)
# Process results for each episode in batch
for item, results in zip(batch_data, batch_results):
episode_idx = item["episode_idx"]
ep = item["episode_metadata"]
timestamps_with_skills = item["timestamps_with_skills"]
description = item["task_description"]
timestamps_processed += len(results)
# Map results back to timestamps and create task indices
timestamp_to_result = {}
for result in results:
ts = result["timestamp"]
timestamp_to_result[ts] = result
# Process each sampled timestamp
for ts_info in timestamps_with_skills:
ts = ts_info["timestamp"]
result = timestamp_to_result.get(ts, {
"timestamp": ts,
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "",
"robot_utterance": "",
})
# Create unique task key
task_key = (
result["user_prompt"],
result["robot_utterance"],
ts_info["current_skill"],
result["scenario_type"],
result["response_type"],
)
# Assign or create task index
if task_key not in high_level_tasks:
high_level_tasks[task_key] = task_index_counter
task_index_counter += 1
current_task_idx = high_level_tasks[task_key]
# Find all frames at this timestamp and assign task_idx
ep_from = ep["dataset_from_index"]
ep_to = ep["dataset_to_index"]
for frame_idx in range(ep_from, ep_to):
frame = dataset[frame_idx]
frame_ts = frame["timestamp"].item()
# Assign to closest sampled timestamp
if abs(frame_ts - ts) < sample_interval_seconds / 2:
task_indices[frame_idx] = current_task_idx
# Save for debugging
debug_outputs.append({
"episode_id": int(episode_idx),
"timestamp": float(ts),
"skill_current": ts_info["current_skill"],
"skills_so_far": ts_info["skills_so_far"],
"task_description": description,
"video_mode": True,
**result,
})
finally:
# Clean up extracted videos
for extracted_path in extracted_videos:
if extracted_path and extracted_path.exists():
extracted_path.unlink()
console.print(f"[green]✓ Processed {timestamps_processed} timestamps across {len(episodes)} episodes[/green]")
# Create tasks DataFrame
tasks_data = []
for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
tasks_data.append({
"task": f"{user_prompt} | {robot_utterance}",
"task_index": task_idx,
"user_prompt": user_prompt,
"robot_utterance": robot_utterance,
"skill": skill,
"scenario_type": scenario_type,
"response_type": response_type,
})
tasks_df = pd.DataFrame(tasks_data).set_index("task")
console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")
return tasks_df, task_indices, debug_outputs
def generate_synthetic_data(
dataset: LeRobotDataset,
pgen: QwenPgen,
skills_metadata: dict,
image_keys: list[str],
sample_interval_seconds: float = 1.0,
console: Console | None = None,
video_mode: bool = False,
video_key: str | None = None,
video_batch_size: int = 1,
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
"""
Generate synthetic dialogue data for entire dataset.
This function processes ALL frames in the dataset, but only calls the VLM
at specified intervals (sample_interval_seconds). Frames between samples
inherit the task_index from the most recent sample.
Args:
dataset: LeRobot dataset with skill annotations
pgen: Qwen model wrapper
skills_metadata: Loaded skills.json metadata
image_keys: List of image observation keys to use (for image mode)
sample_interval_seconds: Generate dialogue every N seconds (default: 1.0)
console: Rich console for logging
video_mode: If True, use video clips instead of sampled images
video_key: Video observation key for video mode (e.g., 'observation.images.base')
video_batch_size: Number of episodes to process in each VLM batch (video mode only)
Returns:
Tuple of (tasks_df, task_indices_array, debug_outputs)
- tasks_df: DataFrame with high-level tasks (user_prompt, robot_utterance, etc.)
- task_indices_array: Array of task indices for each frame (full dataset length)
- debug_outputs: List of debug dictionaries (only for sampled frames)
"""
if console is None:
console = Console()
# Extract metadata
coarse_description = skills_metadata.get("coarse_description", "Complete the task")
episodes = skills_metadata.get("episodes", {})
# Track unique high-level tasks
high_level_tasks = {} # (user_prompt, robot_utterance, skill) -> task_index
task_index_counter = 0 # Start at 0
# Array to store task index for each frame - MUST match full dataset length
full_dataset_length = len(dataset)
task_indices = np.zeros(full_dataset_length, dtype=np.int64)
# For debugging - save to JSONL
debug_outputs = []
# Initialize video extractor if in video mode
video_extractor = VideoExtractor(console) if video_mode else None
if video_mode:
if video_key is None:
raise ValueError("video_key must be provided when video_mode=True")
console.print(f"[cyan]Using VIDEO MODE with video key: {video_key}[/cyan]")
console.print(f"[cyan]Video batch size: {video_batch_size} episodes per VLM call[/cyan]")
# In video mode, process episodes in batches with full videos
return _generate_synthetic_data_video_mode(
dataset=dataset,
pgen=pgen,
skills_metadata=skills_metadata,
video_key=video_key,
video_extractor=video_extractor,
console=console,
sample_interval_seconds=sample_interval_seconds,
batch_size=video_batch_size,
)
# IMAGE MODE (original logic)
# Track sampling
last_sample_timestamp = {} # episode_idx -> last sampled timestamp
last_task_index = {} # episode_idx -> last generated task_index
frames_sampled = 0
console.print(f"[cyan]Processing all {full_dataset_length} frames from {dataset.meta.total_episodes} episodes...[/cyan]")
console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s (fps: {dataset.meta.fps})[/cyan]")
# Process each frame in the FULL dataset
for frame_idx in tqdm(range(full_dataset_length), desc="Generating synthetic dialogue"):
try:
# Get frame data
frame = dataset[frame_idx]
episode_idx = frame["episode_index"].item()
timestamp = frame["timestamp"].item()
# Get episode skills
episode_key = str(episode_idx)
if episode_key not in episodes:
console.print(f"[yellow]Warning: Episode {episode_idx} not in skills metadata[/yellow]")
continue
episode_data = episodes[episode_key]
skills = episode_data.get("skills", [])
description = episode_data.get("description", coarse_description)
# Find current skill
current_skill = get_skill_at_timestamp(skills, timestamp)
if current_skill is None:
console.print(f"[yellow]Warning: No skill found for timestamp {timestamp}[/yellow]")
continue
# Determine if we should sample this frame
should_sample = False
# Always sample first frame of an episode
if episode_idx not in last_sample_timestamp:
should_sample = True
last_sample_timestamp[episode_idx] = timestamp
else:
# Sample if enough time has passed
time_since_last = timestamp - last_sample_timestamp[episode_idx]
if time_since_last >= sample_interval_seconds:
should_sample = True
last_sample_timestamp[episode_idx] = timestamp
# If not sampling, reuse last task index for this episode
if not should_sample:
if episode_idx in last_task_index:
task_indices[frame_idx] = last_task_index[episode_idx]
continue
# Sample this frame - generate synthetic dialogue
frames_sampled += 1
# Build skill history (all skills before current timestamp)
skill_history = []
for skill in skills:
if skill["end"] <= timestamp:
skill_history.append(skill["name"])
# Load images
images = []
for img_key in image_keys:
if img_key in frame:
# Frame images are tensors (C, H, W) in [0, 1]
img_tensor = frame[img_key]
if len(img_tensor.shape) == 4: # (T, C, H, W)
img_tensor = img_tensor[-1] # Take last frame
# Convert to PIL Image
img_array = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
img_pil = Image.fromarray(img_array)
images.append(img_pil)
if not images:
console.print(f"[yellow]Warning: No images found for frame {frame_idx}[/yellow]")
continue
# Generate synthetic dialogue
result = annotate_sample_image(
pgen=pgen,
images=images,
task_description=description,
skill_history=skill_history,
skill_current=current_skill,
)
# Create unique task key
task_key = (
result["user_prompt"],
result["robot_utterance"],
current_skill,
result["scenario_type"],
result["response_type"],
)
# Assign or create task index
if task_key not in high_level_tasks:
high_level_tasks[task_key] = task_index_counter
task_index_counter += 1
current_task_idx = high_level_tasks[task_key]
task_indices[frame_idx] = current_task_idx
last_task_index[episode_idx] = current_task_idx
# Save for debugging
debug_outputs.append({
"episode_id": int(episode_idx),
"frame_index": frame_idx,
"timestamp": float(timestamp),
"skill_current": current_skill,
"skill_history": skill_history,
"task_description": description,
"sampled": True,
**result,
})
except Exception as e:
console.print(f"[red]Error processing frame {frame_idx}: {e}[/red]")
continue
console.print(f"[green]✓ Sampled {frames_sampled} frames out of {full_dataset_length} total ({frames_sampled/full_dataset_length*100:.1f}%)[/green]")
# Create tasks DataFrame
tasks_data = []
for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
tasks_data.append({
"task": f"{user_prompt} | {robot_utterance}",
"task_index": task_idx,
"user_prompt": user_prompt,
"robot_utterance": robot_utterance,
"skill": skill,
"scenario_type": scenario_type,
"response_type": response_type,
})
tasks_df = pd.DataFrame(tasks_data).set_index("task")
console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")
return tasks_df, task_indices, debug_outputs
def save_high_level_tasks(
tasks_df: pd.DataFrame,
dataset_root: Path,
console: Console | None = None,
) -> None:
"""Save high-level tasks to tasks_high_level.parquet."""
if console is None:
console = Console()
output_path = dataset_root / "meta" / "tasks_high_level.parquet"
output_path.parent.mkdir(parents=True, exist_ok=True)
tasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
console.print(f"[green]✓ Saved high-level tasks to {output_path}[/green]")
def save_debug_outputs(
debug_outputs: list[dict],
dataset_root: Path,
console: Console | None = None,
) -> None:
"""Save debug outputs to JSONL file."""
if console is None:
console = Console()
output_path = dataset_root / "meta" / "syn_annotations.jsonl"
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
for item in debug_outputs:
f.write(json.dumps(item) + "\n")
console.print(f"[green]✓ Saved debug annotations to {output_path}[/green]")
# =============================================================================
# Main Entry Point
# =============================================================================
def main():
"""Main entry point for synthetic data generation."""
parser = argparse.ArgumentParser(
description="Generate synthetic dialogue data for hierarchical robot policies",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=textwrap.dedent("""\
Examples:
# Generate synthetic data for a dataset (image mode)
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen2-VL-7B-Instruct \\
--output-dir ./output
# Use video mode with batching (passes full episode videos)
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen2-VL-7B-Instruct \\
--video-mode \\
--video-key observation.images.base \\
--video-batch-size 4
# Use Qwen3 model with custom parameters
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen3-VL-30B-A3B-Instruct \\
--temperature 0.8 \\
--batch-size 1
"""),
)
# Data source
data_group = parser.add_mutually_exclusive_group(required=True)
data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")
# Model configuration
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2-VL-7B-Instruct",
help="VLM model to use (default: Qwen/Qwen2-VL-7B-Instruct)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run model on (default: cuda)",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Model dtype (default: bfloat16)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature (default: 0.7)",
)
# Processing options
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size for processing (default: 1) [currently unused]",
)
parser.add_argument(
"--num-image-views-per-sample",
type=int,
default=1,
help="Number of camera views to use per sample (default: 1)",
)
parser.add_argument(
"--sample-interval",
type=float,
default=1.0,
help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. "
"Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.",
)
parser.add_argument(
"--video-mode",
action="store_true",
help="Use video input instead of sampled image frames. Passes entire skill video clips to the model.",
)
parser.add_argument(
"--video-key",
type=str,
default=None,
help="Video observation key for video mode (e.g., 'observation.images.base'). "
"If not specified, uses the first available video key.",
)
parser.add_argument(
"--video-batch-size",
type=int,
default=1,
help="Number of episodes to process in each VLM batch call in video mode (default: 1)",
)
# Output options
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory for modified dataset",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push modified dataset to HuggingFace Hub",
)
args = parser.parse_args()
console = Console()
# Load dataset
console.print("[cyan]Loading dataset...[/cyan]")
if args.data_dir:
dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir)
dataset_root = Path(args.data_dir)
else:
dataset = LeRobotDataset(repo_id=args.repo_id)
dataset_root = dataset.root
console.print(f"[green]✓ Loaded dataset with {len(dataset)} frames[/green]")
# Load skills metadata
console.print("[cyan]Loading skills metadata...[/cyan]")
skills_metadata = load_skills_metadata(dataset_root)
if skills_metadata is None:
console.print("[red]Error: No skills.json found. Run annotate.py first![/red]")
return
console.print(f"[green]✓ Loaded skills for {len(skills_metadata.get('episodes', {}))} episodes[/green]")
# Initialize model
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
torch_dtype = dtype_map[args.dtype]
console.print(f"[cyan]Initializing {args.model}...[/cyan]")
pgen = QwenPgen(
model_name=args.model,
device=args.device,
torch_dtype=torch_dtype,
temperature=args.temperature,
)
# Get image keys (for image mode)
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
if not args.video_mode:
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
# Determine video key for video mode
video_key = None
if args.video_mode:
if args.video_key:
# Use explicitly provided video key
video_key = args.video_key
if video_key not in dataset.meta.video_keys:
console.print(f"[red]Error: Video key '{video_key}' not found in dataset.[/red]")
console.print(f"[yellow]Available video keys: {', '.join(dataset.meta.video_keys)}[/yellow]")
return
elif dataset.meta.video_keys:
# Use first available video key
video_key = dataset.meta.video_keys[0]
else:
console.print("[red]Error: No video keys found in dataset. Cannot use video mode.[/red]")
return
console.print(f"[cyan]Using video key for video mode: {video_key}[/cyan]")
# Generate synthetic data
tasks_df, task_indices, debug_outputs = generate_synthetic_data(
dataset=dataset,
pgen=pgen,
skills_metadata=skills_metadata,
image_keys=image_keys,
sample_interval_seconds=args.sample_interval,
console=console,
video_mode=args.video_mode,
video_key=video_key,
video_batch_size=args.video_batch_size,
)
# Save high-level tasks
save_high_level_tasks(tasks_df, dataset_root, console)
save_debug_outputs(debug_outputs, dataset_root, console)
# Add task_index_high_level feature to dataset
console.print("[cyan]Adding task_index_high_level feature to dataset...[/cyan]")
# Determine output directory
if args.output_dir:
output_dir = Path(args.output_dir)
repo_id = f"{dataset.repo_id}_with_high_level_tasks"
else:
output_dir = None
repo_id = f"{dataset.repo_id}_with_high_level_tasks"
# Add feature using dataset_tools
feature_info = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
new_dataset = add_features(
dataset=dataset,
features={
"task_index_high_level": (task_indices, feature_info),
},
output_dir=output_dir,
repo_id=repo_id,
)
# copy high level tsk parquet to new output directory
import shutil
shutil.copy(dataset_root / "meta" / "tasks_high_level.parquet", output_dir / "meta" / "tasks_high_level.parquet")
shutil.copy(dataset_root / "meta" / "syn_annotations.jsonl", output_dir / "meta" / "syn_annotations.jsonl")
console.print(f"[bold green]✓ Successfully added task_index_high_level feature![/bold green]")
console.print(f" New dataset saved to: {new_dataset.root}")
console.print(f" Total high-level tasks: {len(tasks_df)}")
# Push to hub if requested
if args.push_to_hub:
if args.data_dir:
console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]")
else:
console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
try:
new_dataset.push_to_hub()
console.print(f"[green]✓ Pushed to {repo_id}[/green]")
except Exception as e:
console.print(f"[red]Push failed: {e}[/red]")
if __name__ == "__main__":
main()