#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Synthetic Data Generation for Hi-Robot Style Hierarchical Policy Training.
This script generates synthetic user prompts (ℓ_t) and robot utterances (u_t) for
hierarchical policy training, using a Qwen VLM as the generator model (pgen).
The pipeline:
1. Loads a LeRobot dataset with skill annotations (from annotate.py)
2. For each frame, generates synthetic dialogue based on:
- Visual context (images at time t)
- Current skill being performed
- History of previous skills
- High-level task description
3. Saves results as high-level tasks and updates dataset with task_index_high_level
Usage:
```bash
python examples/dataset/annotate_pgen.py \
--repo-id lerobot/svla_so101_pickplace \
--model Qwen/Qwen2-VL-7B-Instruct \
--output-dir /path/to/output \
--batch-size 1
```
"""
import argparse
import json
import re
import textwrap
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import torch
from PIL import Image
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from tqdm import tqdm
from lerobot.datasets.dataset_tools import add_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# =============================================================================
# Prompt Template for pgen
# =============================================================================
PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
# Role
You are a robot-assistant dialogue generator for hierarchical robot policies.
# Task
You will receive:
- A list of images showing the current robot scene at time t
- The high-level task: {task_description}
- Previous skill steps completed: {skill_history}
- The next skill to be performed by the robot: {skill_current}
# Your Goal
Generate two things that create a natural human-robot interaction:
1. **user_prompt**: A natural-sounding user request that logically leads to the robot
performing the skill "{skill_current}" given the task context and history.
2. **robot_utterance**: A natural robot reply acknowledging or clarifying the request.
# Guidelines
- The user prompt should be grounded in the visual scene and task context
- Vary interaction types: direct commands, implicit requests, corrections, constraints
- Examples of user prompt styles:
* Direct: "Can you pick up the red brick?"
* Implicit: "I need something red for the tower"
* Negative: "Don't pick up the blue one"
* Constraint: "Make sure to handle it gently"
* Correction: "Actually, move to the other box instead"
- Robot responses should be appropriate: confirmations, clarifications, or error handling
- Use the skill history to ensure continuity (don't repeat past actions)
- Consider world knowledge (dietary preferences, object properties, etc.)
# Scenario Types (choose one that fits):
- **specific_object**: User specifies exact object/action
- **negative_task**: User says what NOT to do
- **situated_correction**: User adjusts based on current state
- **implicit_request**: User implies need without direct command
- **constraint_based**: User adds specific constraints
# Response Types (choose one that fits):
- **confirmation**: Simple "OK, I'll do X"
- **clarification**: "Just to confirm, you want me to..."
- **acknowledgment**: "Got it, [doing action]"
- **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"
# Output Format
Respond ONLY with valid JSON:
{{
"scenario_type": "one of the types above",
"response_type": "one of the types above",
"user_prompt": "natural user request here",
"robot_utterance": "natural robot response here"
}}
The responses must be grounded in the visual scene, the task, and the skill history.
Make it sound like a real human-robot interaction.
""")
def construct_prompt(
task_description: str,
skill_history: list[str],
skill_current: str,
) -> str:
"""
Construct the text prompt for pgen.
Args:
task_description: High-level task description
skill_history: List of previously completed skills
skill_current: Current skill to be performed
Returns:
Formatted prompt string
"""
# Format skill history nicely
if skill_history:
history_str = ", ".join(f'"{s}"' for s in skill_history[-5:]) # Last 5 for context
if len(skill_history) > 5:
history_str = f"... {history_str}"
else:
history_str = "None (starting the task)"
return PGEN_PROMPT_TEMPLATE.format(
task_description=task_description,
skill_history=history_str,
skill_current=skill_current,
)
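# Example (hypothetical skill names): construct_prompt(
#     "stack the bricks", ["pick up red brick", "place red brick"], "pick up blue brick"
# ) renders skill_history as '"pick up red brick", "place red brick"'.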
# =============================================================================
# Qwen VLM Interface
# =============================================================================
class QwenPgen:
"""Qwen VLM wrapper for synthetic dialogue generation."""
def __init__(
self,
model_name: str,
device: str = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
temperature: float = 0.7,
):
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
self.console = Console()
self.device = device
self.model_name = model_name
self.temperature = temperature
self.process_vision_info = process_vision_info
self.console.print(f"[cyan]Loading Qwen model: {model_name}...[/cyan]")
# Load model based on name
if "qwen3" in model_name.lower():
from transformers import Qwen3VLMoeForConditionalGeneration
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
else:
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
def call_qwen(
self,
images: list[Image.Image | str],
prompt: str,
) -> dict[str, str]:
"""
Call Qwen VLM to generate synthetic dialogue.
Args:
images: List of PIL Images or image paths
prompt: Text prompt for generation
Returns:
Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
"""
        # Build the multimodal message content. qwen_vl_utils' process_vision_info
        # accepts both image file paths and PIL Images, so each entry can be
        # passed through unchanged.
        content = [{"type": "image", "image": img} for img in images]
        content.append({"type": "text", "text": prompt})
messages = [
{
"role": "user",
"content": content,
}
]
# Process inputs
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = self.process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(self.device)
# Generate
with torch.no_grad():
generated_ids = self.model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=self.temperature,
)
        # Decode only the newly generated tokens, stripping the echoed prompt prefix
response = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
skip_special_tokens=True,
)[0].strip()
return self._parse_response(response)
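    # Example (hypothetical call): call_qwen([Image.open("frame.png")], prompt)
    # returns a dict like {"scenario_type": "specific_object",
    # "response_type": "confirmation", "user_prompt": "...", "robot_utterance": "..."}.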
def _parse_response(self, response: str) -> dict[str, str]:
"""Parse JSON response from model."""
# Extract JSON from response
if "```json" in response:
response = response.split("```json")[1].split("```")[0]
elif "```" in response:
response = response.split("```")[1].split("```")[0]
try:
data = json.loads(response)
return {
"scenario_type": data.get("scenario_type", "specific_object"),
"response_type": data.get("response_type", "confirmation"),
"user_prompt": data.get("user_prompt", ""),
"robot_utterance": data.get("robot_utterance", ""),
}
except json.JSONDecodeError:
# Try to find JSON object in response
match = re.search(r"\{.*\}", response, re.DOTALL)
if match:
data = json.loads(match.group())
return {
"scenario_type": data.get("scenario_type", "specific_object"),
"response_type": data.get("response_type", "confirmation"),
"user_prompt": data.get("user_prompt", ""),
"robot_utterance": data.get("robot_utterance", ""),
}
raise ValueError(f"Could not parse response: {response[:200]}...")
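# Note: _parse_response tolerates replies wrapped in Markdown fences
# ("```json ... ```" or bare "``` ... ```") as well as a JSON object embedded
# in surrounding prose, falling back to a regex search before giving up.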
# =============================================================================
# Annotation Pipeline
# =============================================================================
def load_skills_metadata(dataset_root: Path) -> dict | None:
"""Load skills.json metadata from annotated dataset."""
skills_path = dataset_root / "meta" / "skills.json"
if skills_path.exists():
with open(skills_path) as f:
return json.load(f)
return None
def get_skill_at_timestamp(skills: list[dict], timestamp: float) -> str | None:
    """Find which skill covers a given timestamp.

    Skills are treated as half-open intervals [start, end). Timestamps that fall
    outside every interval (e.g. the final frame of an episode) resolve to the
    last skill.
    """
    for skill in skills:
        if skill["start"] <= timestamp < skill["end"]:
            return skill["name"]
    # Out-of-range timestamps (including the last frame) map to the last skill
    return skills[-1]["name"] if skills else None
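# Example: with skills = [{"name": "reach", "start": 0.0, "end": 1.5},
#                         {"name": "grasp", "start": 1.5, "end": 3.0}],
# timestamp 1.5 resolves to "grasp" (half-open intervals) and timestamp 3.2
# falls back to "grasp", the last skill.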
def annotate_sample(
pgen: QwenPgen,
images: list[Image.Image | str],
task_description: str,
skill_history: list[str],
skill_current: str,
) -> dict[str, str]:
"""
Generate synthetic dialogue for a single sample.
Args:
pgen: Qwen model wrapper
images: List of images at current timestep
task_description: High-level task description
skill_history: Previous skills completed
skill_current: Current skill being performed
Returns:
Dictionary with generated dialogue
"""
prompt = construct_prompt(task_description, skill_history, skill_current)
result = pgen.call_qwen(images, prompt)
return result
def generate_synthetic_data(
dataset: LeRobotDataset,
pgen: QwenPgen,
skills_metadata: dict,
image_keys: list[str],
sample_interval_seconds: float = 1.0,
console: Console | None = None,
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
"""
Generate synthetic dialogue data for entire dataset.
This function processes ALL frames in the dataset, but only calls the VLM
at specified intervals (sample_interval_seconds). Frames between samples
inherit the task_index from the most recent sample.
Args:
dataset: LeRobot dataset with skill annotations
pgen: Qwen model wrapper
skills_metadata: Loaded skills.json metadata
image_keys: List of image observation keys to use
sample_interval_seconds: Generate dialogue every N seconds (default: 1.0)
console: Rich console for logging
Returns:
Tuple of (tasks_df, task_indices_array, debug_outputs)
- tasks_df: DataFrame with high-level tasks (user_prompt, robot_utterance, etc.)
- task_indices_array: Array of task indices for each frame (full dataset length)
- debug_outputs: List of debug dictionaries (only for sampled frames)
"""
if console is None:
console = Console()
# Extract metadata
coarse_description = skills_metadata.get("coarse_description", "Complete the task")
episodes = skills_metadata.get("episodes", {})
# Track unique high-level tasks
high_level_tasks = {} # (user_prompt, robot_utterance, skill) -> task_index
task_index_counter = 0 # Start at 0
# Array to store task index for each frame - MUST match full dataset length
full_dataset_length = len(dataset)
task_indices = np.zeros(full_dataset_length, dtype=np.int64)
# For debugging - save to JSONL
debug_outputs = []
# Track sampling
last_sample_timestamp = {} # episode_idx -> last sampled timestamp
last_task_index = {} # episode_idx -> last generated task_index
frames_sampled = 0
console.print(f"[cyan]Processing all {full_dataset_length} frames from {dataset.meta.total_episodes} episodes...[/cyan]")
console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s (fps: {dataset.meta.fps})[/cyan]")
# Process each frame in the FULL dataset
for frame_idx in tqdm(range(full_dataset_length), desc="Generating synthetic dialogue"):
try:
# Get frame data
frame = dataset[frame_idx]
episode_idx = frame["episode_index"].item()
timestamp = frame["timestamp"].item()
# Get episode skills
episode_key = str(episode_idx)
if episode_key not in episodes:
console.print(f"[yellow]Warning: Episode {episode_idx} not in skills metadata[/yellow]")
continue
episode_data = episodes[episode_key]
skills = episode_data.get("skills", [])
description = episode_data.get("description", coarse_description)
# Find current skill
current_skill = get_skill_at_timestamp(skills, timestamp)
if current_skill is None:
console.print(f"[yellow]Warning: No skill found for timestamp {timestamp}[/yellow]")
continue
# Determine if we should sample this frame
should_sample = False
# Always sample first frame of an episode
if episode_idx not in last_sample_timestamp:
should_sample = True
last_sample_timestamp[episode_idx] = timestamp
else:
# Sample if enough time has passed
time_since_last = timestamp - last_sample_timestamp[episode_idx]
if time_since_last >= sample_interval_seconds:
should_sample = True
last_sample_timestamp[episode_idx] = timestamp
# If not sampling, reuse last task index for this episode
if not should_sample:
if episode_idx in last_task_index:
task_indices[frame_idx] = last_task_index[episode_idx]
continue
# Sample this frame - generate synthetic dialogue
frames_sampled += 1
# Build skill history (all skills before current timestamp)
skill_history = []
for skill in skills:
if skill["end"] <= timestamp:
skill_history.append(skill["name"])
# Load images
images = []
for img_key in image_keys:
if img_key in frame:
# Frame images are tensors (C, H, W) in [0, 1]
img_tensor = frame[img_key]
if len(img_tensor.shape) == 4: # (T, C, H, W)
img_tensor = img_tensor[-1] # Take last frame
# Convert to PIL Image
img_array = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
img_pil = Image.fromarray(img_array)
images.append(img_pil)
if not images:
console.print(f"[yellow]Warning: No images found for frame {frame_idx}[/yellow]")
continue
# Generate synthetic dialogue
result = annotate_sample(
pgen=pgen,
images=images,
task_description=description,
skill_history=skill_history,
skill_current=current_skill,
)
# Create unique task key
task_key = (
result["user_prompt"],
result["robot_utterance"],
current_skill,
result["scenario_type"],
result["response_type"],
)
# Assign or create task index
if task_key not in high_level_tasks:
high_level_tasks[task_key] = task_index_counter
task_index_counter += 1
current_task_idx = high_level_tasks[task_key]
task_indices[frame_idx] = current_task_idx
last_task_index[episode_idx] = current_task_idx
# Save for debugging
debug_outputs.append({
"episode_id": int(episode_idx),
"frame_index": frame_idx,
"timestamp": float(timestamp),
"skill_current": current_skill,
"skill_history": skill_history,
"task_description": description,
"sampled": True,
**result,
})
except Exception as e:
console.print(f"[red]Error processing frame {frame_idx}: {e}[/red]")
continue
console.print(f"[green]✓ Sampled {frames_sampled} frames out of {full_dataset_length} total ({frames_sampled/full_dataset_length*100:.1f}%)[/green]")
# Create tasks DataFrame
tasks_data = []
for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
tasks_data.append({
"task": f"{user_prompt} | {robot_utterance}",
"task_index": task_idx,
"user_prompt": user_prompt,
"robot_utterance": robot_utterance,
"skill": skill,
"scenario_type": scenario_type,
"response_type": response_type,
})
tasks_df = pd.DataFrame(tasks_data).set_index("task")
console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")
return tasks_df, task_indices, debug_outputs
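# Note: at the default 1.0 s sampling interval, a 30 fps dataset triggers one
# VLM call roughly every 30 frames; the ~29 frames in between inherit the task
# index of the most recent sampled frame.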
def save_high_level_tasks(
tasks_df: pd.DataFrame,
dataset_root: Path,
console: Console | None = None,
) -> None:
"""Save high-level tasks to tasks_high_level.parquet."""
if console is None:
console = Console()
output_path = dataset_root / "meta" / "tasks_high_level.parquet"
output_path.parent.mkdir(parents=True, exist_ok=True)
tasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
console.print(f"[green]✓ Saved high-level tasks to {output_path}[/green]")
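# The saved table can be inspected directly, e.g. (a minimal sketch; the path
# is relative to the dataset root):
#   import pandas as pd
#   df = pd.read_parquet("meta/tasks_high_level.parquet")
#   print(df[["task_index", "user_prompt", "robot_utterance", "skill"]].head())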
def save_debug_outputs(
debug_outputs: list[dict],
dataset_root: Path,
console: Console | None = None,
) -> None:
"""Save debug outputs to JSONL file."""
if console is None:
console = Console()
output_path = dataset_root / "meta" / "syn_annotations.jsonl"
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
for item in debug_outputs:
f.write(json.dumps(item) + "\n")
console.print(f"[green]✓ Saved debug annotations to {output_path}[/green]")
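# Each JSONL line describes one sampled frame, e.g. (illustrative values):
# {"episode_id": 0, "frame_index": 12, "timestamp": 0.4,
#  "skill_current": "grasp brick", "skill_history": ["reach for brick"],
#  "task_description": "stack the bricks", "sampled": true,
#  "scenario_type": "specific_object", "response_type": "confirmation",
#  "user_prompt": "Can you grab the brick?", "robot_utterance": "Got it, grasping the brick."}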
# =============================================================================
# Main Entry Point
# =============================================================================
def main():
"""Main entry point for synthetic data generation."""
parser = argparse.ArgumentParser(
description="Generate synthetic dialogue data for hierarchical robot policies",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=textwrap.dedent("""\
Examples:
# Generate synthetic data for a dataset
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen2-VL-7B-Instruct \\
--output-dir ./output
# Use Qwen3 model with custom parameters
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
--model Qwen/Qwen3-VL-30B-A3B-Instruct \\
--temperature 0.8 \\
--batch-size 1
"""),
)
# Data source
data_group = parser.add_mutually_exclusive_group(required=True)
data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")
# Model configuration
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen2-VL-7B-Instruct",
help="VLM model to use (default: Qwen/Qwen2-VL-7B-Instruct)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run model on (default: cuda)",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float16", "float32"],
help="Model dtype (default: bfloat16)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature (default: 0.7)",
)
# Processing options
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size for processing (default: 1) [currently unused]",
)
parser.add_argument(
"--num-image-views-per-sample",
type=int,
default=1,
help="Number of camera views to use per sample (default: 1)",
)
parser.add_argument(
"--sample-interval",
type=float,
default=1.0,
help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. "
"Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.",
)
# Output options
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory for modified dataset",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push modified dataset to HuggingFace Hub",
)
args = parser.parse_args()
console = Console()
# Load dataset
console.print("[cyan]Loading dataset...[/cyan]")
if args.data_dir:
dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir)
dataset_root = Path(args.data_dir)
else:
dataset = LeRobotDataset(repo_id=args.repo_id)
dataset_root = dataset.root
console.print(f"[green]✓ Loaded dataset with {len(dataset)} frames[/green]")
# Load skills metadata
console.print("[cyan]Loading skills metadata...[/cyan]")
skills_metadata = load_skills_metadata(dataset_root)
if skills_metadata is None:
console.print("[red]Error: No skills.json found. Run annotate.py first![/red]")
return
console.print(f"[green]✓ Loaded skills for {len(skills_metadata.get('episodes', {}))} episodes[/green]")
# Initialize model
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
torch_dtype = dtype_map[args.dtype]
console.print(f"[cyan]Initializing {args.model}...[/cyan]")
pgen = QwenPgen(
model_name=args.model,
device=args.device,
torch_dtype=torch_dtype,
temperature=args.temperature,
)
# Get image keys
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
# Generate synthetic data
tasks_df, task_indices, debug_outputs = generate_synthetic_data(
dataset=dataset,
pgen=pgen,
skills_metadata=skills_metadata,
image_keys=image_keys,
sample_interval_seconds=args.sample_interval,
console=console,
)
# Save high-level tasks
save_high_level_tasks(tasks_df, dataset_root, console)
save_debug_outputs(debug_outputs, dataset_root, console)
# Add task_index_high_level feature to dataset
console.print("[cyan]Adding task_index_high_level feature to dataset...[/cyan]")
    # Determine the output directory; the derived repo_id is the same either way
    output_dir = Path(args.output_dir) if args.output_dir else None
    repo_id = f"{dataset.repo_id}_with_high_level_tasks"
# Add feature using dataset_tools
feature_info = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
new_dataset = add_features(
dataset=dataset,
features={
"task_index_high_level": (task_indices, feature_info),
},
output_dir=output_dir,
repo_id=repo_id,
)
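    # Downstream, each frame's task_index_high_level can be joined against
    # meta/tasks_high_level.parquet to recover the generated (user_prompt,
    # robot_utterance) pair for that frame.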
    console.print("[bold green]✓ Successfully added task_index_high_level feature![/bold green]")
console.print(f" New dataset saved to: {new_dataset.root}")
console.print(f" Total high-level tasks: {len(tasks_df)}")
# Push to hub if requested
if args.push_to_hub:
if args.data_dir:
console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]")
else:
console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
try:
new_dataset.push_to_hub(push_videos=False)
console.print(f"[green]✓ Pushed to {repo_id}[/green]")
except Exception as e:
console.print(f"[red]Push failed: {e}[/red]")
if __name__ == "__main__":
main()