allow loading high level tasks

This commit is contained in:
Jade Choghari
2025-12-10 16:22:54 +00:00
parent 8edbd5b55e
commit 3c11946755
6 changed files with 770 additions and 56 deletions
+1 -1
View File
@@ -513,7 +513,7 @@ class Qwen3VL(BaseVLM):
data = json.loads(match.group())
skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data]
breakpoint()
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
+722 -54
View File
@@ -23,29 +23,44 @@ hierarchical policy training using Qwen VLM as the generator model (pgen).
The pipeline:
1. Loads a LeRobot dataset with skill annotations (from annotate.py)
2. For each frame, generates synthetic dialogue based on:
- Visual context (images at time t)
- Visual context (images at time t OR video clips in video mode)
- Current skill being performed
- History of previous skills
- High-level task description
3. Saves results as high-level tasks and updates dataset with task_index_high_level
Modes:
- Image Mode (default): Samples frames at intervals and sends images to the model
- Video Mode (--video-mode): Passes entire skill video clips to the model
Usage:
```bash
# Image mode (default)
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--output-dir /path/to/output \
--batch-size 1
--output-dir /path/to/output
# Video mode with batch processing
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--video-mode \
--video-key observation.images.base \
--video-batch-size 4
```
"""
import argparse
import json
import re
import subprocess
import tempfile
import textwrap
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import pandas as pd
import torch
@@ -58,11 +73,91 @@ from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# =============================================================================
# Video Extraction Utilities
# =============================================================================
class VideoExtractor:
    """Utilities for extracting and processing video segments from LeRobot datasets."""

    def __init__(self, console: Console | None = None):
        # Fall back to a fresh Rich console when the caller does not supply one.
        self.console = console or Console()

    def extract_episode_video(
        self,
        video_path: Path,
        start_timestamp: float,
        end_timestamp: float,
        target_fps: int = 1,
    ) -> Path:
        """
        Extract a specific episode segment from a concatenated video file.

        Args:
            video_path: Path to the source video file
            start_timestamp: Start time in seconds
            end_timestamp: End time in seconds
            target_fps: Target frames per second for output

        Returns:
            Path to the extracted temporary video file. The caller is
            responsible for deleting it.

        Raises:
            RuntimeError: If ffmpeg is missing, exits non-zero, or produces
                an invalid (missing or suspiciously small) output file.
        """
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tmp_path = Path(tmp_file.name)
        tmp_file.close()

        duration = end_timestamp - start_timestamp
        cmd = [
            "ffmpeg",
            "-i",
            str(video_path),
            "-ss",
            str(start_timestamp),
            "-t",
            str(duration),
            "-r",
            str(target_fps),
            "-c:v",
            "libx264",
            "-preset",
            "ultrafast",
            "-crf",
            "23",
            "-an",  # strip audio — the VLM only consumes frames
            "-y",   # overwrite the pre-created temp file
            str(tmp_path),
        ]
        try:
            subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
        except subprocess.CalledProcessError as e:
            # Don't leak the pre-created temp file when the encode fails.
            tmp_path.unlink(missing_ok=True)
            raise RuntimeError(f"FFmpeg failed: {e}") from e
        except FileNotFoundError as e:
            tmp_path.unlink(missing_ok=True)
            raise RuntimeError("FFmpeg not found. Please install ffmpeg.") from e

        # A tiny output file almost certainly means the encode silently failed.
        if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
            tmp_path.unlink(missing_ok=True)
            raise RuntimeError("Video extraction produced invalid file")
        return tmp_path

    def get_video_duration(self, video_path: Path) -> float:
        """Get duration of a video file in seconds (0.0 if the file has no readable frames)."""
        cap = cv2.VideoCapture(str(video_path))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # guard against 0/unreadable fps metadata
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        return frame_count / fps
# =============================================================================
# Prompt Template for pgen
# =============================================================================
PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
PGEN_PROMPT_TEMPLATE_IMAGE = textwrap.dedent("""\
# Role
You are a robot-assistant dialogue generator for hierarchical robot policies.
@@ -118,14 +213,76 @@ PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
Make it sound like a real human-robot interaction.
""")
PGEN_PROMPT_TEMPLATE_VIDEO = textwrap.dedent("""\
# Role
You are a robot-assistant dialogue generator for hierarchical robot policies.
# Task
You are watching a full robot demonstration video for the task: {task_description}
For each timestamp below, generate natural human-robot dialogue that would have led to the observed behavior.
At each timestamp, you'll see:
- What skills have been completed so far (cumulative history)
- What skill is currently being executed
{timestamp_context}
# Your Goal
For EACH timestamp, generate:
1. **user_prompt**: A natural user request that would lead to the robot performing the current skill
2. **robot_utterance**: A natural robot response acknowledging the request
# Guidelines
- Watch the video from start to each timestamp to understand the context
- Ground prompts in what's visible in the video at that time
- Vary interaction types: direct commands, implicit requests, corrections, constraints
- Examples of user prompt styles:
* Direct: "Can you pick up the red brick?"
* Implicit: "I need something red for the tower"
* Negative: "Don't pick up the blue one"
* Constraint: "Make sure to handle it gently"
* Correction: "Actually, move to the other box instead"
- Robot responses should be appropriate: confirmations, clarifications, or error handling
- Ensure continuity across timestamps (don't contradict earlier dialogue)
- Consider world knowledge (dietary preferences, object properties, etc.)
# Scenario Types:
- **specific_object**: User specifies exact object/action
- **negative_task**: User says what NOT to do
- **situated_correction**: User adjusts based on current state
- **implicit_request**: User implies need without direct command
- **constraint_based**: User adds specific constraints
# Response Types:
- **confirmation**: Simple "OK, I'll do X"
- **clarification**: "Just to confirm, you want me to..."
- **acknowledgment**: "Got it, [doing action]"
- **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"
# Output Format
Respond ONLY with valid JSON array:
[
{{
"timestamp": timestamp_value,
"scenario_type": "one of the types above",
"response_type": "one of the types above",
"user_prompt": "natural user request here",
"robot_utterance": "natural robot response here"
}},
... (one entry per timestamp)
]
Make it sound like a real human-robot interaction grounded in the video.
""")
def construct_prompt(
def construct_prompt_image(
task_description: str,
skill_history: list[str],
skill_current: str,
) -> str:
"""
Construct the text prompt for pgen.
Construct the text prompt for pgen in image mode.
Args:
task_description: High-level task description
@@ -143,13 +300,54 @@ def construct_prompt(
else:
history_str = "None (starting the task)"
return PGEN_PROMPT_TEMPLATE.format(
return PGEN_PROMPT_TEMPLATE_IMAGE.format(
task_description=task_description,
skill_history=history_str,
skill_current=skill_current,
)
def construct_prompt_video(
    task_description: str,
    timestamps_with_skills: list[dict],
) -> str:
    """
    Construct the text prompt for pgen in video mode.

    Args:
        task_description: High-level task description
        timestamps_with_skills: List of dicts with keys:
            - timestamp: float
            - skills_so_far: list[str]
            - current_skill: str

    Returns:
        Formatted prompt string
    """

    def _bullet(entry: dict) -> str:
        # One bullet per sampled timestamp: completed skills plus the active one.
        done = entry["skills_so_far"]
        done_str = ", ".join(f'"{name}"' for name in done) if done else "None (starting)"
        return (
            f"- **Timestamp {entry['timestamp']:.2f}s**: "
            f"Skills completed: [{done_str}] | Current skill: \"{entry['current_skill']}\""
        )

    timestamp_context = "\n".join(_bullet(entry) for entry in timestamps_with_skills)
    return PGEN_PROMPT_TEMPLATE_VIDEO.format(
        task_description=task_description,
        timestamp_context=timestamp_context,
    )
# =============================================================================
# Qwen VLM Interface
# =============================================================================
@@ -191,40 +389,54 @@ class QwenPgen:
def call_qwen(
self,
images: list[Image.Image | str],
prompt: str,
images: list[Image.Image | str] | None = None,
prompt: str = "",
video: str | Path | None = None,
) -> dict[str, str]:
"""
Call Qwen VLM to generate synthetic dialogue for a single request.
Args:
images: List of PIL Images or image paths
images: List of PIL Images or image paths (for image mode)
prompt: Text prompt for generation
video: Path to video file (for video mode)
Returns:
Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
"""
# Use batch method with single item
results = self.call_qwen_batch([images], [prompt])
results = self.call_qwen_batch(
batch_images=[images] if images else [None],
batch_prompts=[prompt],
batch_videos=[video] if video else [None],
)
return results[0]
def call_qwen_batch(
self,
batch_images: list[list[Image.Image | str]],
batch_images: list[list[Image.Image | str] | None],
batch_prompts: list[str],
batch_videos: list[str | Path | None] | None = None,
) -> list[dict[str, str]]:
"""
Call Qwen VLM to generate synthetic dialogue for a batch of requests.
Args:
batch_images: List of image lists, one per request
batch_images: List of image lists, one per request (None for video mode)
batch_prompts: List of text prompts, one per request
batch_videos: List of video paths, one per request (None for image mode)
Returns:
List of dictionaries, each with keys: scenario_type, response_type, user_prompt, robot_utterance
"""
if len(batch_images) != len(batch_prompts):
raise ValueError(f"Batch size mismatch: {len(batch_images)} image lists vs {len(batch_prompts)} prompts")
if batch_videos is None:
batch_videos = [None] * len(batch_images)
if len(batch_images) != len(batch_prompts) or len(batch_images) != len(batch_videos):
raise ValueError(
f"Batch size mismatch: {len(batch_images)} image lists vs "
f"{len(batch_prompts)} prompts vs {len(batch_videos)} videos"
)
batch_size = len(batch_images)
if batch_size == 0:
@@ -232,14 +444,21 @@ class QwenPgen:
# Build messages for each item in batch
all_messages = []
for images, prompt in zip(batch_images, batch_prompts):
for images, prompt, video in zip(batch_images, batch_prompts, batch_videos):
content = []
for img in images:
if isinstance(img, str):
content.append({"type": "image", "image": img})
else:
# PIL Image
content.append({"type": "image", "image": img})
# Add video or images
if video is not None:
# Video mode
content.append({"type": "video", "video": str(video), "fps": 1.0})
elif images is not None:
# Image mode
for img in images:
if isinstance(img, str):
content.append({"type": "image", "image": img})
else:
# PIL Image
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": prompt})
@@ -381,7 +600,7 @@ def get_skill_at_timestamp(skills: list[dict], timestamp: float) -> str | None:
return skills[-1]["name"] if skills else None
def annotate_sample(
def annotate_sample_image(
pgen: QwenPgen,
images: list[Image.Image | str],
task_description: str,
@@ -389,7 +608,7 @@ def annotate_sample(
skill_current: str,
) -> dict[str, str]:
"""
Generate synthetic dialogue for a single sample.
Generate synthetic dialogue for a single sample using images.
Args:
pgen: Qwen model wrapper
@@ -401,42 +620,418 @@ def annotate_sample(
Returns:
Dictionary with generated dialogue
"""
prompt = construct_prompt(task_description, skill_history, skill_current)
result = pgen.call_qwen(images, prompt)
prompt = construct_prompt_image(task_description, skill_history, skill_current)
result = pgen.call_qwen(images=images, prompt=prompt, video=None)
return result
def annotate_samples_batch(
def annotate_episode_video(
pgen: QwenPgen,
batch_images: list[list[Image.Image | str]],
batch_task_descriptions: list[str],
batch_skill_histories: list[list[str]],
batch_skill_currents: list[str],
) -> list[dict[str, str]]:
video: str | Path,
task_description: str,
timestamps_with_skills: list[dict],
) -> list[dict[str, Any]]:
"""
Generate synthetic dialogue for a batch of samples.
Generate synthetic dialogue for an entire episode using video.
Args:
pgen: Qwen model wrapper
batch_images: List of image lists, one per sample
batch_task_descriptions: List of task descriptions
batch_skill_histories: List of skill history lists
batch_skill_currents: List of current skills
video: Path to episode video file
task_description: High-level task description
timestamps_with_skills: List of dicts with timestamp, skills_so_far, current_skill
Returns:
List of dictionaries with generated dialogue
List of dictionaries with generated dialogue, one per timestamp
"""
# Construct prompts for entire batch
batch_prompts = []
for task_desc, skill_hist, skill_curr in zip(
batch_task_descriptions, batch_skill_histories, batch_skill_currents
):
prompt = construct_prompt(task_desc, skill_hist, skill_curr)
batch_prompts.append(prompt)
# Use batch method with single episode
results = annotate_episodes_video_batch(
pgen=pgen,
batch_videos=[video],
batch_task_descriptions=[task_description],
batch_timestamps_with_skills=[timestamps_with_skills],
)
return results[0]
def annotate_episodes_video_batch(
    pgen: QwenPgen,
    batch_videos: list[str | Path],
    batch_task_descriptions: list[str],
    batch_timestamps_with_skills: list[list[dict]],
) -> list[list[dict[str, Any]]]:
    """
    Generate synthetic dialogue for multiple episodes using videos in batch.

    Args:
        pgen: Qwen model wrapper
        batch_videos: List of paths to episode video files
        batch_task_descriptions: List of high-level task descriptions
        batch_timestamps_with_skills: List of timestamp lists, one per episode

    Returns:
        List of result lists, one per episode (each containing dicts with generated dialogue)
    """
    batch_size = len(batch_videos)
    if batch_size == 0:
        return []

    # Build one chat message per episode: the full episode video plus the
    # timestamp-annotated prompt.
    all_messages = []
    for video, task_desc, timestamps_with_skills in zip(
        batch_videos, batch_task_descriptions, batch_timestamps_with_skills
    ):
        prompt = construct_prompt_video(task_desc, timestamps_with_skills)
        content = [
            {"type": "video", "video": str(video), "fps": 1.0},
            {"type": "text", "text": prompt},
        ]
        messages = [{"role": "user", "content": content}]
        all_messages.append(messages)

    # Flatten per-episode vision inputs into the single lists the processor expects.
    all_texts = []
    all_image_inputs = []
    all_video_inputs = []
    for messages in all_messages:
        text = pgen.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = pgen.process_vision_info(messages)
        all_texts.append(text)
        all_image_inputs.extend(image_inputs or [])
        all_video_inputs.extend(video_inputs or [])

    inputs = pgen.processor(
        text=all_texts,
        images=all_image_inputs if all_image_inputs else None,
        videos=all_video_inputs if all_video_inputs else None,
        padding=True,
        return_tensors="pt",
    ).to(pgen.device)

    with torch.no_grad():
        generated_ids = pgen.model.generate(
            **inputs,
            max_new_tokens=2048,  # Larger for multiple timestamps per episode
            do_sample=True,
            temperature=pgen.temperature,
        )

    # Strip the prompt tokens from each output before decoding.
    responses = pgen.processor.batch_decode(
        [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
        skip_special_tokens=True,
    )

    # Parse each response into one dialogue entry per requested timestamp.
    all_results = []
    for response, timestamps_with_skills in zip(responses, batch_timestamps_with_skills):
        results = _parse_video_response(response.strip(), timestamps_with_skills)
        all_results.append(results)

    return all_results
def _parse_video_response(response: str, timestamps_with_skills: list[dict]) -> list[dict[str, Any]]:
"""Parse JSON array response from video mode."""
# Extract JSON from response
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
if not isinstance(data, list):
# If it's a dict with a list inside
if "annotations" in data:
data = data["annotations"]
elif "results" in data:
data = data["results"]
else:
raise ValueError("Expected JSON array or dict with 'annotations'/'results' key")
results = []
for item in data:
results.append({
"timestamp": item.get("timestamp", 0.0),
"scenario_type": item.get("scenario_type", "specific_object"),
"response_type": item.get("response_type", "confirmation"),
"user_prompt": item.get("user_prompt", ""),
"robot_utterance": item.get("robot_utterance", ""),
})
return results
except json.JSONDecodeError:
# Try to find JSON array in response
match = re.search(r"\[.*\]", response, re.DOTALL)
if match:
data = json.loads(match.group())
results = []
for item in data:
results.append({
"timestamp": item.get("timestamp", 0.0),
"scenario_type": item.get("scenario_type", "specific_object"),
"response_type": item.get("response_type", "confirmation"),
"user_prompt": item.get("user_prompt", ""),
"robot_utterance": item.get("robot_utterance", ""),
})
return results
breakpoint()
# Fallback: return empty results for each timestamp
print(f"Warning: Could not parse video response: {response[:200]}...")
return [
{
"timestamp": ts["timestamp"],
"scenario_type": "specific_object",
"response_type": "confirmation",
"user_prompt": "",
"robot_utterance": "",
}
for ts in timestamps_with_skills
]
def _generate_synthetic_data_video_mode(
    dataset: LeRobotDataset,
    pgen: QwenPgen,
    skills_metadata: dict,
    video_key: str,
    video_extractor: VideoExtractor,
    console: Console,
    sample_interval_seconds: float = 1.0,
    batch_size: int = 1,
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
    """
    Generate synthetic dialogue data using video mode with batched VLM calls.

    The VLM sees full episode videos and generates dialogue for multiple
    timestamps per episode, with cumulative skill history at each timestamp.

    Args:
        dataset: LeRobot dataset with skill annotations
        pgen: Qwen model wrapper
        skills_metadata: Loaded skills.json metadata
        video_key: Video observation key (e.g., 'observation.images.base')
        video_extractor: VideoExtractor instance
        console: Rich console for logging
        sample_interval_seconds: Sample timestamps at this interval
        batch_size: Number of episodes to process in each VLM batch call

    Returns:
        Tuple of (tasks_df, task_indices_array, debug_outputs)
    """
    coarse_description = skills_metadata.get("coarse_description", "Complete the task")
    episodes = skills_metadata.get("episodes", {})
    # Track unique high-level tasks: maps a dedup key -> stable task index.
    high_level_tasks = {}
    task_index_counter = 0
    # Array to store task index for each frame (default 0 for unmatched frames).
    full_dataset_length = len(dataset)
    task_indices = np.zeros(full_dataset_length, dtype=np.int64)
    debug_outputs = []
    timestamps_processed = 0
    console.print(f"[cyan]Processing {len(episodes)} episodes in VIDEO MODE with batch_size={batch_size}...[/cyan]")
    console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s[/cyan]")
    # Convert episodes dict to list for batching.
    episode_list = list(episodes.items())
    # Process episodes in batches.
    for batch_start in tqdm(range(0, len(episode_list), batch_size), desc="Processing episode batches"):
        batch_end = min(batch_start + batch_size, len(episode_list))
        batch_episodes = episode_list[batch_start:batch_end]
        # Collect data for this batch.
        batch_data = []
        extracted_videos = []
        for episode_key, episode_data in batch_episodes:
            episode_idx = int(episode_key)
            skills = episode_data.get("skills", [])
            description = episode_data.get("description", coarse_description)
            if not skills:
                console.print(f"[yellow]Warning: Episode {episode_idx} has no skills[/yellow]")
                continue
            # Get video path and extract full episode.
            extracted_path = None
            try:
                video_path = dataset.root / dataset.meta.get_video_file_path(episode_idx, video_key)
                if not video_path.exists():
                    console.print(f"[yellow]Warning: Video not found for episode {episode_idx}[/yellow]")
                    continue
                # Get episode timestamps (offsets into the concatenated video file).
                ep = dataset.meta.episodes[episode_idx]
                episode_start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
                episode_end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
                duration = episode_end_ts - episode_start_ts
                # Extract FULL episode video (temp file; cleaned up in the finally below).
                extracted_path = video_extractor.extract_episode_video(
                    video_path, episode_start_ts, episode_end_ts, target_fps=1
                )
                extracted_videos.append(extracted_path)
            except Exception as e:
                # Best-effort: skip episodes whose video cannot be extracted.
                console.print(f"[yellow]Warning: Failed to extract video for episode {episode_idx}: {e}[/yellow]")
                continue
            # Build list of timestamps to sample.
            timestamps_with_skills = []
            current_time = 0.0
            while current_time <= duration:
                # Find which skill is active at this timestamp.
                current_skill = None
                skills_so_far = []
                for skill in skills:
                    if skill["end"] <= current_time:
                        skills_so_far.append(skill["name"])
                    elif skill["start"] <= current_time < skill["end"]:
                        current_skill = skill["name"]
                        break
                    # NOTE(review): this branch looks unreachable — `current_time >= skill["end"]`
                    # is already captured by the first condition above, so after the last
                    # skill ends no current_skill is assigned and the timestamp is dropped.
                    # Confirm whether trailing timestamps should map to the last skill.
                    elif current_time >= skill["end"] and skill == skills[-1]:
                        current_skill = skill["name"]
                        break
                if current_skill:
                    timestamps_with_skills.append({
                        "timestamp": current_time,
                        "skills_so_far": skills_so_far.copy(),
                        "current_skill": current_skill,
                    })
                current_time += sample_interval_seconds
            if not timestamps_with_skills:
                console.print(f"[yellow]Warning: No valid timestamps for episode {episode_idx}[/yellow]")
                continue
            # Store batch item.
            batch_data.append({
                "episode_idx": episode_idx,
                "episode_metadata": ep,
                "video_path": extracted_path,
                "task_description": description,
                "timestamps_with_skills": timestamps_with_skills,
                "skills": skills,
            })
        if not batch_data:
            continue
        # BATCHED VLM CALL for all episodes in batch.
        try:
            batch_results = annotate_episodes_video_batch(
                pgen=pgen,
                batch_videos=[item["video_path"] for item in batch_data],
                batch_task_descriptions=[item["task_description"] for item in batch_data],
                batch_timestamps_with_skills=[item["timestamps_with_skills"] for item in batch_data],
            )
            # Process results for each episode in batch.
            for item, results in zip(batch_data, batch_results):
                episode_idx = item["episode_idx"]
                ep = item["episode_metadata"]
                timestamps_with_skills = item["timestamps_with_skills"]
                description = item["task_description"]
                timestamps_processed += len(results)
                # Map results back to timestamps and create task indices.
                # NOTE(review): this keys on exact float equality between the model's
                # echoed timestamp and the sampled one — fragile if the model rounds;
                # unmatched timestamps fall back to the empty-dialogue default below.
                timestamp_to_result = {}
                for result in results:
                    ts = result["timestamp"]
                    timestamp_to_result[ts] = result
                # Process each sampled timestamp.
                for ts_info in timestamps_with_skills:
                    ts = ts_info["timestamp"]
                    result = timestamp_to_result.get(ts, {
                        "timestamp": ts,
                        "scenario_type": "specific_object",
                        "response_type": "confirmation",
                        "user_prompt": "",
                        "robot_utterance": "",
                    })
                    # Create unique task key (dedups identical dialogue/skill combos).
                    task_key = (
                        result["user_prompt"],
                        result["robot_utterance"],
                        ts_info["current_skill"],
                        result["scenario_type"],
                        result["response_type"],
                    )
                    # Assign or create task index.
                    if task_key not in high_level_tasks:
                        high_level_tasks[task_key] = task_index_counter
                        task_index_counter += 1
                    current_task_idx = high_level_tasks[task_key]
                    # Find all frames at this timestamp and assign task_idx.
                    ep_from = ep["dataset_from_index"]
                    ep_to = ep["dataset_to_index"]
                    # NOTE(review): `dataset[frame_idx]` likely decodes full frames just to
                    # read a timestamp — O(frames x timestamps) per episode; verify cost.
                    for frame_idx in range(ep_from, ep_to):
                        frame = dataset[frame_idx]
                        frame_ts = frame["timestamp"].item()
                        # Assign to closest sampled timestamp (within half an interval).
                        if abs(frame_ts - ts) < sample_interval_seconds / 2:
                            task_indices[frame_idx] = current_task_idx
                    # Save for debugging.
                    debug_outputs.append({
                        "episode_id": int(episode_idx),
                        "timestamp": float(ts),
                        "skill_current": ts_info["current_skill"],
                        "skills_so_far": ts_info["skills_so_far"],
                        "task_description": description,
                        "video_mode": True,
                        **result,
                    })
        finally:
            # Clean up extracted videos even if the VLM call or parsing fails.
            for extracted_path in extracted_videos:
                if extracted_path and extracted_path.exists():
                    extracted_path.unlink()
    console.print(f"[green]✓ Processed {timestamps_processed} timestamps across {len(episodes)} episodes[/green]")
    # Create tasks DataFrame, ordered by assigned task index.
    tasks_data = []
    for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
        user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
        tasks_data.append({
            "task": f"{user_prompt} | {robot_utterance}",
            "task_index": task_idx,
            "user_prompt": user_prompt,
            "robot_utterance": robot_utterance,
            "skill": skill,
            "scenario_type": scenario_type,
            "response_type": response_type,
        })
    tasks_df = pd.DataFrame(tasks_data).set_index("task")
    console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")
    return tasks_df, task_indices, debug_outputs
def generate_synthetic_data(
@@ -446,6 +1041,9 @@ def generate_synthetic_data(
image_keys: list[str],
sample_interval_seconds: float = 1.0,
console: Console | None = None,
video_mode: bool = False,
video_key: str | None = None,
video_batch_size: int = 1,
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
"""
Generate synthetic dialogue data for entire dataset.
@@ -458,9 +1056,12 @@ def generate_synthetic_data(
dataset: LeRobot dataset with skill annotations
pgen: Qwen model wrapper
skills_metadata: Loaded skills.json metadata
image_keys: List of image observation keys to use
image_keys: List of image observation keys to use (for image mode)
sample_interval_seconds: Generate dialogue every N seconds (default: 1.0)
console: Rich console for logging
video_mode: If True, use video clips instead of sampled images
video_key: Video observation key for video mode (e.g., 'observation.images.base')
video_batch_size: Number of episodes to process in each VLM batch (video mode only)
Returns:
Tuple of (tasks_df, task_indices_array, debug_outputs)
@@ -486,6 +1087,27 @@ def generate_synthetic_data(
# For debugging - save to JSONL
debug_outputs = []
# Initialize video extractor if in video mode
video_extractor = VideoExtractor(console) if video_mode else None
if video_mode:
if video_key is None:
raise ValueError("video_key must be provided when video_mode=True")
console.print(f"[cyan]Using VIDEO MODE with video key: {video_key}[/cyan]")
console.print(f"[cyan]Video batch size: {video_batch_size} episodes per VLM call[/cyan]")
# In video mode, process episodes in batches with full videos
return _generate_synthetic_data_video_mode(
dataset=dataset,
pgen=pgen,
skills_metadata=skills_metadata,
video_key=video_key,
video_extractor=video_extractor,
console=console,
sample_interval_seconds=sample_interval_seconds,
batch_size=video_batch_size,
)
# IMAGE MODE (original logic)
# Track sampling
last_sample_timestamp = {} # episode_idx -> last sampled timestamp
last_task_index = {} # episode_idx -> last generated task_index
@@ -566,7 +1188,7 @@ def generate_synthetic_data(
continue
# Generate synthetic dialogue
result = annotate_sample(
result = annotate_sample_image(
pgen=pgen,
images=images,
task_description=description,
@@ -677,11 +1299,18 @@ def main():
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=textwrap.dedent("""\
Examples:
# Generate synthetic data for a dataset
# Generate synthetic data for a dataset (image mode)
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen2-VL-7B-Instruct \\
--output-dir ./output
# Use video mode with batching (passes full episode videos)
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen2-VL-7B-Instruct \\
--video-mode \\
--video-key observation.images.base \\
--video-batch-size 4
# Use Qwen3 model with custom parameters
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen3-VL-30B-A3B-Instruct \\
@@ -742,6 +1371,24 @@ def main():
help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. "
"Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.",
)
parser.add_argument(
"--video-mode",
action="store_true",
help="Use video input instead of sampled image frames. Passes entire skill video clips to the model.",
)
parser.add_argument(
"--video-key",
type=str,
default=None,
help="Video observation key for video mode (e.g., 'observation.images.base'). "
"If not specified, uses the first available video key.",
)
parser.add_argument(
"--video-batch-size",
type=int,
default=1,
help="Number of episodes to process in each VLM batch call in video mode (default: 1)",
)
# Output options
parser.add_argument(
@@ -795,9 +1442,28 @@ def main():
temperature=args.temperature,
)
# Get image keys
# Get image keys (for image mode)
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
if not args.video_mode:
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
# Determine video key for video mode
video_key = None
if args.video_mode:
if args.video_key:
# Use explicitly provided video key
video_key = args.video_key
if video_key not in dataset.meta.video_keys:
console.print(f"[red]Error: Video key '{video_key}' not found in dataset.[/red]")
console.print(f"[yellow]Available video keys: {', '.join(dataset.meta.video_keys)}[/yellow]")
return
elif dataset.meta.video_keys:
# Use first available video key
video_key = dataset.meta.video_keys[0]
else:
console.print("[red]Error: No video keys found in dataset. Cannot use video mode.[/red]")
return
console.print(f"[cyan]Using video key for video mode: {video_key}[/cyan]")
# Generate synthetic data
tasks_df, task_indices, debug_outputs = generate_synthetic_data(
@@ -807,6 +1473,9 @@ def main():
image_keys=image_keys,
sample_interval_seconds=args.sample_interval,
console=console,
video_mode=args.video_mode,
video_key=video_key,
video_batch_size=args.video_batch_size,
)
# Save high-level tasks
@@ -830,7 +1499,6 @@ def main():
"shape": (1,),
"names": None,
}
breakpoint()
new_dataset = add_features(
dataset=dataset,
features={
+23
View File
@@ -0,0 +1,23 @@
"""Smoke-test script: load the pgen-annotated dataset and inspect one training batch."""
import torch

from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Local dataset produced by annotate_pgen.py (contains task_index_high_level).
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=32,
    shuffle=True,
)

batch = next(iter(dataloader))
print(batch.keys())
print(batch['task_index_high_level'].shape)
print(batch['task_index_high_level'])
print(batch['user_prompt'][0])
print(batch['robot_utterance'][0])
print(batch['task'][0])
# Removed leftover breakpoint() and unused imports (HfApi, lerobot,
# LeRobotDatasetMetadata) — debugger hooks must not be committed.
+12 -1
View File
@@ -9,7 +9,7 @@ MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
# Alternative: MODEL="Qwen/Qwen2-VL-7B-Instruct"
OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations"
OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations1"
BATCH_SIZE=32
TEMPERATURE=0.9
SAMPLE_INTERVAL=5.0  # Generate dialogue every 5 seconds (all episodes processed)
@@ -20,6 +20,7 @@ python examples/dataset/annotate_pgen.py \
--model "$MODEL" \
--output-dir "$OUTPUT_DIR" \
--temperature "$TEMPERATURE" \
--batch-size "$BATCH_SIZE" \
--sample-interval "$SAMPLE_INTERVAL" \
--num-image-views-per-sample 1
@@ -29,3 +30,13 @@ python examples/dataset/annotate_pgen.py \
# To push to hub after generation:
# Add --push-to-hub flag
# Efficient batch processing: $BATCH_SIZE episodes per VLM call
# python examples/dataset/annotate_pgen.py \
# --repo-id "$REPO_ID" \
# --model "$MODEL" \
# --output-dir "$OUTPUT_DIR" \
# --video-mode \
# --video-key observation.images.up \
# --video-batch-size "$BATCH_SIZE" \
# --sample-interval 1.0