mirror of https://github.com/huggingface/lerobot.git

allow loading high level tasks
@@ -513,7 +513,7 @@ class Qwen3VL(BaseVLM):
data = json.loads(match.group())
skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data]
breakpoint()

raise ValueError(f"Could not parse skills from response: {response[:200]}...")

@@ -23,29 +23,44 @@ hierarchical policy training using Qwen VLM as the generator model (pgen).

The pipeline:
1. Loads a LeRobot dataset with skill annotations (from annotate.py)
2. For each frame, generates synthetic dialogue based on:
   - Visual context (images at time t)
   - Visual context (images at time t OR video clips in video mode)
   - Current skill being performed
   - History of previous skills
   - High-level task description
3. Saves results as high-level tasks and updates dataset with task_index_high_level

Modes:
- Image Mode (default): Samples frames at intervals and sends images to the model
- Video Mode (--video-mode): Passes entire skill video clips to the model

Usage:
```bash
# Image mode (default)
python examples/dataset/annotate_pgen.py \
    --repo-id lerobot/svla_so101_pickplace \
    --model Qwen/Qwen2-VL-7B-Instruct \
    --output-dir /path/to/output \
    --batch-size 1
    --output-dir /path/to/output

# Video mode with batch processing
python examples/dataset/annotate_pgen.py \
    --repo-id lerobot/svla_so101_pickplace \
    --model Qwen/Qwen2-VL-7B-Instruct \
    --video-mode \
    --video-key observation.images.base \
    --video-batch-size 4
```
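
For reference, the video-mode path below assumes the skills.json written by annotate.py
roughly follows this shape (field names are taken from the accessors in this script;
values are purely illustrative):

```python
{
    "coarse_description": "Pick the object and place it in the box",
    "episodes": {
        "0": {
            "description": "Pick up the red brick and place it in the blue box",
            "skills": [
                {"name": "reach for the red brick", "start": 0.0, "end": 2.5},
                {"name": "grasp the red brick", "start": 2.5, "end": 4.0},
            ],
        }
    },
}
```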
"""

import argparse
import json
import re
import subprocess
import tempfile
import textwrap
from pathlib import Path
from typing import Any

import cv2
import numpy as np
import pandas as pd
import torch

@@ -58,11 +73,91 @@ from lerobot.datasets.dataset_tools import add_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Video Extraction Utilities
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class VideoExtractor:
|
||||
"""Utilities for extracting and processing video segments from LeRobot datasets."""
|
||||
|
||||
def __init__(self, console: Console | None = None):
|
||||
self.console = console or Console()
|
||||
|
||||
def extract_episode_video(
|
||||
self,
|
||||
video_path: Path,
|
||||
start_timestamp: float,
|
||||
end_timestamp: float,
|
||||
target_fps: int = 1,
|
||||
) -> Path:
|
||||
"""
|
||||
Extract a specific episode segment from a concatenated video file.
|
||||
|
||||
Args:
|
||||
video_path: Path to the source video file
|
||||
start_timestamp: Start time in seconds
|
||||
end_timestamp: End time in seconds
|
||||
target_fps: Target frames per second for output
|
||||
|
||||
Returns:
|
||||
Path to the extracted temporary video file
|
||||
"""
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
||||
tmp_path = Path(tmp_file.name)
|
||||
tmp_file.close()
|
||||
|
||||
duration = end_timestamp - start_timestamp
|
||||
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-ss",
|
||||
str(start_timestamp),
|
||||
"-t",
|
||||
str(duration),
|
||||
"-r",
|
||||
str(target_fps),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"ultrafast",
|
||||
"-crf",
|
||||
"23",
|
||||
"-an",
|
||||
"-y",
|
||||
str(tmp_path),
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"FFmpeg failed: {e}") from e
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError("FFmpeg not found. Please install ffmpeg.")
|
||||
|
||||
if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
|
||||
if tmp_path.exists():
|
||||
tmp_path.unlink()
|
||||
raise RuntimeError("Video extraction produced invalid file")
|
||||
|
||||
return tmp_path
|
||||
|
||||
def get_video_duration(self, video_path: Path) -> float:
|
||||
"""Get duration of a video file in seconds."""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) or 30
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
cap.release()
|
||||
return frame_count / fps
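# A minimal usage sketch of VideoExtractor (the paths and timestamps below are
# hypothetical). The video-mode loop further down is the real caller and, as there,
# the caller is responsible for deleting the temporary clip:
#
#     extractor = VideoExtractor()
#     clip_path = extractor.extract_episode_video(
#         video_path=Path("/path/to/source_video.mp4"),
#         start_timestamp=12.0,
#         end_timestamp=45.5,
#         target_fps=1,
#     )
#     try:
#         print(f"clip duration: {extractor.get_video_duration(clip_path):.1f}s")
#     finally:
#         clip_path.unlink()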
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Prompt Template for pgen
|
||||
# =============================================================================
|
||||
|
||||
PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
|
||||
PGEN_PROMPT_TEMPLATE_IMAGE = textwrap.dedent("""\
|
||||
# Role
|
||||
You are a robot-assistant dialogue generator for hierarchical robot policies.
|
||||
|
||||
@@ -118,14 +213,76 @@ PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
Make it sound like a real human-robot interaction.
""")

PGEN_PROMPT_TEMPLATE_VIDEO = textwrap.dedent("""\
# Role
You are a robot-assistant dialogue generator for hierarchical robot policies.

# Task
You are watching a full robot demonstration video for the task: {task_description}

For each timestamp below, generate natural human-robot dialogue that would have led to the observed behavior.
At each timestamp, you'll see:
- What skills have been completed so far (cumulative history)
- What skill is currently being executed

{timestamp_context}

# Your Goal
For EACH timestamp, generate:
1. **user_prompt**: A natural user request that would lead to the robot performing the current skill
2. **robot_utterance**: A natural robot response acknowledging the request

# Guidelines
- Watch the video from start to each timestamp to understand the context
- Ground prompts in what's visible in the video at that time
- Vary interaction types: direct commands, implicit requests, corrections, constraints
- Examples of user prompt styles:
  * Direct: "Can you pick up the red brick?"
  * Implicit: "I need something red for the tower"
  * Negative: "Don't pick up the blue one"
  * Constraint: "Make sure to handle it gently"
  * Correction: "Actually, move to the other box instead"
- Robot responses should be appropriate: confirmations, clarifications, or error handling
- Ensure continuity across timestamps (don't contradict earlier dialogue)
- Consider world knowledge (dietary preferences, object properties, etc.)

# Scenario Types:
- **specific_object**: User specifies exact object/action
- **negative_task**: User says what NOT to do
- **situated_correction**: User adjusts based on current state
- **implicit_request**: User implies need without direct command
- **constraint_based**: User adds specific constraints

# Response Types:
- **confirmation**: Simple "OK, I'll do X"
- **clarification**: "Just to confirm, you want me to..."
- **acknowledgment**: "Got it, [doing action]"
- **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"

# Output Format
Respond ONLY with valid JSON array:
[
  {{
    "timestamp": timestamp_value,
    "scenario_type": "one of the types above",
    "response_type": "one of the types above",
    "user_prompt": "natural user request here",
    "robot_utterance": "natural robot response here"
  }},
  ... (one entry per timestamp)
]

Make it sound like a real human-robot interaction grounded in the video.
""")

def construct_prompt(
|
||||
|
||||
def construct_prompt_image(
|
||||
task_description: str,
|
||||
skill_history: list[str],
|
||||
skill_current: str,
|
||||
) -> str:
|
||||
"""
|
||||
Construct the text prompt for pgen.
|
||||
Construct the text prompt for pgen in image mode.
|
||||
|
||||
Args:
|
||||
task_description: High-level task description
|
||||
@@ -143,13 +300,54 @@ def construct_prompt(
|
||||
else:
|
||||
history_str = "None (starting the task)"
|
||||
|
||||
return PGEN_PROMPT_TEMPLATE.format(
|
||||
return PGEN_PROMPT_TEMPLATE_IMAGE.format(
|
||||
task_description=task_description,
|
||||
skill_history=history_str,
|
||||
skill_current=skill_current,
|
||||
)
|
||||
|
||||
|
||||
def construct_prompt_video(
|
||||
task_description: str,
|
||||
timestamps_with_skills: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Construct the text prompt for pgen in video mode.
|
||||
|
||||
Args:
|
||||
task_description: High-level task description
|
||||
timestamps_with_skills: List of dicts with keys:
|
||||
- timestamp: float
|
||||
- skills_so_far: list[str]
|
||||
- current_skill: str
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
# Build timestamp context
|
||||
timestamp_lines = []
|
||||
for item in timestamps_with_skills:
|
||||
ts = item["timestamp"]
|
||||
skills_so_far = item["skills_so_far"]
|
||||
current_skill = item["current_skill"]
|
||||
|
||||
if skills_so_far:
|
||||
skills_str = ", ".join(f'"{s}"' for s in skills_so_far)
|
||||
else:
|
||||
skills_str = "None (starting)"
|
||||
|
||||
timestamp_lines.append(
|
||||
f"- **Timestamp {ts:.2f}s**: Skills completed: [{skills_str}] | Current skill: \"{current_skill}\""
|
||||
)
|
||||
|
||||
timestamp_context = "\n".join(timestamp_lines)
|
||||
|
||||
return PGEN_PROMPT_TEMPLATE_VIDEO.format(
|
||||
task_description=task_description,
|
||||
timestamp_context=timestamp_context,
|
||||
)
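# A minimal sketch of how construct_prompt_video is fed (dict keys match the
# docstring above; the skill names and task text are made up for illustration):
#
#     example_timestamps = [
#         {"timestamp": 0.0, "skills_so_far": [], "current_skill": "reach for the brick"},
#         {"timestamp": 3.0, "skills_so_far": ["reach for the brick"], "current_skill": "grasp the brick"},
#     ]
#     prompt = construct_prompt_video(
#         task_description="Stack the brick on the box",
#         timestamps_with_skills=example_timestamps,
#     )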
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Qwen VLM Interface
|
||||
# =============================================================================
|
||||
@@ -191,40 +389,54 @@ class QwenPgen:
|
||||
|
||||
def call_qwen(
|
||||
self,
|
||||
images: list[Image.Image | str],
|
||||
prompt: str,
|
||||
images: list[Image.Image | str] | None = None,
|
||||
prompt: str = "",
|
||||
video: str | Path | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Call Qwen VLM to generate synthetic dialogue for a single request.
|
||||
|
||||
Args:
|
||||
images: List of PIL Images or image paths
|
||||
images: List of PIL Images or image paths (for image mode)
|
||||
prompt: Text prompt for generation
|
||||
video: Path to video file (for video mode)
|
||||
|
||||
Returns:
|
||||
Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
|
||||
"""
|
||||
# Use batch method with single item
|
||||
results = self.call_qwen_batch([images], [prompt])
|
||||
results = self.call_qwen_batch(
|
||||
batch_images=[images] if images else [None],
|
||||
batch_prompts=[prompt],
|
||||
batch_videos=[video] if video else [None],
|
||||
)
|
||||
return results[0]
|
||||
|
||||
def call_qwen_batch(
|
||||
self,
|
||||
batch_images: list[list[Image.Image | str]],
|
||||
batch_images: list[list[Image.Image | str] | None],
|
||||
batch_prompts: list[str],
|
||||
batch_videos: list[str | Path | None] | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""
|
||||
Call Qwen VLM to generate synthetic dialogue for a batch of requests.
|
||||
|
||||
Args:
|
||||
batch_images: List of image lists, one per request
|
||||
batch_images: List of image lists, one per request (None for video mode)
|
||||
batch_prompts: List of text prompts, one per request
|
||||
batch_videos: List of video paths, one per request (None for image mode)
|
||||
|
||||
Returns:
|
||||
List of dictionaries, each with keys: scenario_type, response_type, user_prompt, robot_utterance
|
||||
"""
|
||||
if len(batch_images) != len(batch_prompts):
|
||||
raise ValueError(f"Batch size mismatch: {len(batch_images)} image lists vs {len(batch_prompts)} prompts")
|
||||
if batch_videos is None:
|
||||
batch_videos = [None] * len(batch_images)
|
||||
|
||||
if len(batch_images) != len(batch_prompts) or len(batch_images) != len(batch_videos):
|
||||
raise ValueError(
|
||||
f"Batch size mismatch: {len(batch_images)} image lists vs "
|
||||
f"{len(batch_prompts)} prompts vs {len(batch_videos)} videos"
|
||||
)
|
||||
|
||||
batch_size = len(batch_images)
|
||||
if batch_size == 0:
|
||||
@@ -232,14 +444,21 @@ class QwenPgen:
|
||||
|
||||
# Build messages for each item in batch
|
||||
all_messages = []
|
||||
for images, prompt in zip(batch_images, batch_prompts):
|
||||
for images, prompt, video in zip(batch_images, batch_prompts, batch_videos):
|
||||
content = []
|
||||
for img in images:
|
||||
if isinstance(img, str):
|
||||
content.append({"type": "image", "image": img})
|
||||
else:
|
||||
# PIL Image
|
||||
content.append({"type": "image", "image": img})
|
||||
|
||||
# Add video or images
|
||||
if video is not None:
|
||||
# Video mode
|
||||
content.append({"type": "video", "video": str(video), "fps": 1.0})
|
||||
elif images is not None:
|
||||
# Image mode
|
||||
for img in images:
|
||||
if isinstance(img, str):
|
||||
content.append({"type": "image", "image": img})
|
||||
else:
|
||||
# PIL Image
|
||||
content.append({"type": "image", "image": img})
|
||||
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
@@ -381,7 +600,7 @@ def get_skill_at_timestamp(skills: list[dict], timestamp: float) -> str | None:
|
||||
return skills[-1]["name"] if skills else None
|
||||
|
||||
|
||||
def annotate_sample(
|
||||
def annotate_sample_image(
|
||||
pgen: QwenPgen,
|
||||
images: list[Image.Image | str],
|
||||
task_description: str,
|
||||
@@ -389,7 +608,7 @@ def annotate_sample(
|
||||
skill_current: str,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Generate synthetic dialogue for a single sample.
|
||||
Generate synthetic dialogue for a single sample using images.
|
||||
|
||||
Args:
|
||||
pgen: Qwen model wrapper
|
||||
@@ -401,42 +620,418 @@ def annotate_sample(
|
||||
Returns:
|
||||
Dictionary with generated dialogue
|
||||
"""
|
||||
prompt = construct_prompt(task_description, skill_history, skill_current)
|
||||
result = pgen.call_qwen(images, prompt)
|
||||
prompt = construct_prompt_image(task_description, skill_history, skill_current)
|
||||
result = pgen.call_qwen(images=images, prompt=prompt, video=None)
|
||||
return result
|
||||
|
||||
|
||||
def annotate_samples_batch(
|
||||
def annotate_episode_video(
|
||||
pgen: QwenPgen,
|
||||
batch_images: list[list[Image.Image | str]],
|
||||
batch_task_descriptions: list[str],
|
||||
batch_skill_histories: list[list[str]],
|
||||
batch_skill_currents: list[str],
|
||||
) -> list[dict[str, str]]:
|
||||
video: str | Path,
|
||||
task_description: str,
|
||||
timestamps_with_skills: list[dict],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Generate synthetic dialogue for a batch of samples.
|
||||
Generate synthetic dialogue for an entire episode using video.
|
||||
|
||||
Args:
|
||||
pgen: Qwen model wrapper
|
||||
batch_images: List of image lists, one per sample
|
||||
batch_task_descriptions: List of task descriptions
|
||||
batch_skill_histories: List of skill history lists
|
||||
batch_skill_currents: List of current skills
|
||||
video: Path to episode video file
|
||||
task_description: High-level task description
|
||||
timestamps_with_skills: List of dicts with timestamp, skills_so_far, current_skill
|
||||
|
||||
Returns:
|
||||
List of dictionaries with generated dialogue
|
||||
List of dictionaries with generated dialogue, one per timestamp
|
||||
"""
|
||||
# Construct prompts for entire batch
|
||||
batch_prompts = []
|
||||
for task_desc, skill_hist, skill_curr in zip(
|
||||
batch_task_descriptions, batch_skill_histories, batch_skill_currents
|
||||
):
|
||||
prompt = construct_prompt(task_desc, skill_hist, skill_curr)
|
||||
batch_prompts.append(prompt)
|
||||
# Use batch method with single episode
|
||||
results = annotate_episodes_video_batch(
|
||||
pgen=pgen,
|
||||
batch_videos=[video],
|
||||
batch_task_descriptions=[task_description],
|
||||
batch_timestamps_with_skills=[timestamps_with_skills],
|
||||
)
|
||||
return results[0]
|
||||
|
||||
|
||||
def annotate_episodes_video_batch(
|
||||
pgen: QwenPgen,
|
||||
batch_videos: list[str | Path],
|
||||
batch_task_descriptions: list[str],
|
||||
batch_timestamps_with_skills: list[list[dict]],
|
||||
) -> list[list[dict[str, Any]]]:
|
||||
"""
|
||||
Generate synthetic dialogue for multiple episodes using videos in batch.
|
||||
|
||||
# Process entire batch in one call
|
||||
results = pgen.call_qwen_batch(batch_images, batch_prompts)
|
||||
return results
|
||||
Args:
|
||||
pgen: Qwen model wrapper
|
||||
batch_videos: List of paths to episode video files
|
||||
batch_task_descriptions: List of high-level task descriptions
|
||||
batch_timestamps_with_skills: List of timestamp lists, one per episode
|
||||
|
||||
Returns:
|
||||
List of result lists, one per episode (each containing dicts with generated dialogue)
|
||||
"""
|
||||
batch_size = len(batch_videos)
|
||||
if batch_size == 0:
|
||||
return []
|
||||
|
||||
# Build messages for each episode
|
||||
all_messages = []
|
||||
for video, task_desc, timestamps_with_skills in zip(
|
||||
batch_videos, batch_task_descriptions, batch_timestamps_with_skills
|
||||
):
|
||||
prompt = construct_prompt_video(task_desc, timestamps_with_skills)
|
||||
|
||||
content = [
|
||||
{"type": "video", "video": str(video), "fps": 1.0},
|
||||
{"type": "text", "text": prompt},
|
||||
]
|
||||
|
||||
messages = [{"role": "user", "content": content}]
|
||||
all_messages.append(messages)
|
||||
|
||||
# Process all episodes through Qwen in batch
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
|
||||
for messages in all_messages:
|
||||
text = pgen.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = pgen.process_vision_info(messages)
|
||||
all_texts.append(text)
|
||||
all_image_inputs.extend(image_inputs or [])
|
||||
all_video_inputs.extend(video_inputs or [])
|
||||
|
||||
inputs = pgen.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
videos=all_video_inputs if all_video_inputs else None,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(pgen.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = pgen.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=2048, # Larger for multiple timestamps per episode
|
||||
do_sample=True,
|
||||
temperature=pgen.temperature,
|
||||
)
|
||||
|
||||
responses = pgen.processor.batch_decode(
|
||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
# Parse each response
|
||||
all_results = []
|
||||
for response, timestamps_with_skills in zip(responses, batch_timestamps_with_skills):
|
||||
results = _parse_video_response(response.strip(), timestamps_with_skills)
|
||||
all_results.append(results)
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def _parse_video_response(response: str, timestamps_with_skills: list[dict]) -> list[dict[str, Any]]:
|
||||
"""Parse JSON array response from video mode."""
|
||||
# Extract JSON from response
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
try:
|
||||
data = json.loads(response)
|
||||
if not isinstance(data, list):
|
||||
# If it's a dict with a list inside
|
||||
if "annotations" in data:
|
||||
data = data["annotations"]
|
||||
elif "results" in data:
|
||||
data = data["results"]
|
||||
else:
|
||||
raise ValueError("Expected JSON array or dict with 'annotations'/'results' key")
|
||||
|
||||
results = []
|
||||
for item in data:
|
||||
results.append({
|
||||
"timestamp": item.get("timestamp", 0.0),
|
||||
"scenario_type": item.get("scenario_type", "specific_object"),
|
||||
"response_type": item.get("response_type", "confirmation"),
|
||||
"user_prompt": item.get("user_prompt", ""),
|
||||
"robot_utterance": item.get("robot_utterance", ""),
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON array in response
|
||||
match = re.search(r"\[.*\]", response, re.DOTALL)
|
||||
if match:
|
||||
data = json.loads(match.group())
|
||||
results = []
|
||||
for item in data:
|
||||
results.append({
|
||||
"timestamp": item.get("timestamp", 0.0),
|
||||
"scenario_type": item.get("scenario_type", "specific_object"),
|
||||
"response_type": item.get("response_type", "confirmation"),
|
||||
"user_prompt": item.get("user_prompt", ""),
|
||||
"robot_utterance": item.get("robot_utterance", ""),
|
||||
})
|
||||
return results
|
||||
|
||||
|
||||
# Fallback: return empty results for each timestamp
|
||||
print(f"Warning: Could not parse video response: {response[:200]}...")
|
||||
return [
|
||||
{
|
||||
"timestamp": ts["timestamp"],
|
||||
"scenario_type": "specific_object",
|
||||
"response_type": "confirmation",
|
||||
"user_prompt": "",
|
||||
"robot_utterance": "",
|
||||
}
|
||||
for ts in timestamps_with_skills
|
||||
]
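# A minimal sketch of what _parse_video_response handles (the fenced reply and the
# timestamp list below are illustrative; a well-formed reply yields one dict per entry):
#
#     raw = '```json\n[{"timestamp": 0.0, "scenario_type": "specific_object", '
#     raw += '"response_type": "confirmation", "user_prompt": "Pick up the red brick", '
#     raw += '"robot_utterance": "Got it."}]\n```'
#     parsed = _parse_video_response(
#         raw,
#         timestamps_with_skills=[{"timestamp": 0.0, "skills_so_far": [], "current_skill": "pick"}],
#     )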
|
||||
|
||||
|
||||
|
||||
|
||||
def _generate_synthetic_data_video_mode(
|
||||
dataset: LeRobotDataset,
|
||||
pgen: QwenPgen,
|
||||
skills_metadata: dict,
|
||||
video_key: str,
|
||||
video_extractor: VideoExtractor,
|
||||
console: Console,
|
||||
sample_interval_seconds: float = 1.0,
|
||||
batch_size: int = 1,
|
||||
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
|
||||
"""
|
||||
Generate synthetic dialogue data using video mode with batched VLM calls.
|
||||
|
||||
The VLM sees full episode videos and generates dialogue for multiple
|
||||
timestamps per episode, with cumulative skill history at each timestamp.
|
||||
|
||||
Args:
|
||||
dataset: LeRobot dataset with skill annotations
|
||||
pgen: Qwen model wrapper
|
||||
skills_metadata: Loaded skills.json metadata
|
||||
video_key: Video observation key (e.g., 'observation.images.base')
|
||||
video_extractor: VideoExtractor instance
|
||||
console: Rich console for logging
|
||||
sample_interval_seconds: Sample timestamps at this interval
|
||||
batch_size: Number of episodes to process in each VLM batch call
|
||||
|
||||
Returns:
|
||||
Tuple of (tasks_df, task_indices_array, debug_outputs)
|
||||
"""
|
||||
coarse_description = skills_metadata.get("coarse_description", "Complete the task")
|
||||
episodes = skills_metadata.get("episodes", {})
|
||||
|
||||
# Track unique high-level tasks
|
||||
high_level_tasks = {}
|
||||
task_index_counter = 0
|
||||
|
||||
# Array to store task index for each frame
|
||||
full_dataset_length = len(dataset)
|
||||
task_indices = np.zeros(full_dataset_length, dtype=np.int64)
|
||||
|
||||
debug_outputs = []
|
||||
timestamps_processed = 0
|
||||
|
||||
console.print(f"[cyan]Processing {len(episodes)} episodes in VIDEO MODE with batch_size={batch_size}...[/cyan]")
|
||||
console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s[/cyan]")
|
||||
|
||||
# Convert episodes dict to list for batching
|
||||
episode_list = list(episodes.items())
|
||||
|
||||
# Process episodes in batches
|
||||
for batch_start in tqdm(range(0, len(episode_list), batch_size), desc="Processing episode batches"):
|
||||
batch_end = min(batch_start + batch_size, len(episode_list))
|
||||
batch_episodes = episode_list[batch_start:batch_end]
|
||||
|
||||
# Collect data for this batch
|
||||
batch_data = []
|
||||
extracted_videos = []
|
||||
|
||||
for episode_key, episode_data in batch_episodes:
|
||||
episode_idx = int(episode_key)
|
||||
skills = episode_data.get("skills", [])
|
||||
description = episode_data.get("description", coarse_description)
|
||||
|
||||
if not skills:
|
||||
console.print(f"[yellow]Warning: Episode {episode_idx} has no skills[/yellow]")
|
||||
continue
|
||||
|
||||
# Get video path and extract full episode
|
||||
extracted_path = None
|
||||
try:
|
||||
video_path = dataset.root / dataset.meta.get_video_file_path(episode_idx, video_key)
|
||||
if not video_path.exists():
|
||||
console.print(f"[yellow]Warning: Video not found for episode {episode_idx}[/yellow]")
|
||||
continue
|
||||
|
||||
# Get episode timestamps
|
||||
ep = dataset.meta.episodes[episode_idx]
|
||||
episode_start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
|
||||
episode_end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
|
||||
duration = episode_end_ts - episode_start_ts
|
||||
|
||||
# Extract FULL episode video
|
||||
extracted_path = video_extractor.extract_episode_video(
|
||||
video_path, episode_start_ts, episode_end_ts, target_fps=1
|
||||
)
|
||||
extracted_videos.append(extracted_path)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Warning: Failed to extract video for episode {episode_idx}: {e}[/yellow]")
|
||||
continue
|
||||
|
||||
# Build list of timestamps to sample
|
||||
timestamps_with_skills = []
|
||||
current_time = 0.0
|
||||
|
||||
while current_time <= duration:
|
||||
# Find which skill is active at this timestamp
|
||||
current_skill = None
|
||||
skills_so_far = []
|
||||
|
||||
for skill in skills:
|
||||
if skill["end"] <= current_time:
|
||||
skills_so_far.append(skill["name"])
|
||||
elif skill["start"] <= current_time < skill["end"]:
|
||||
current_skill = skill["name"]
|
||||
break
|
||||
elif current_time >= skill["end"] and skill == skills[-1]:
|
||||
current_skill = skill["name"]
|
||||
break
|
||||
|
||||
if current_skill:
|
||||
timestamps_with_skills.append({
|
||||
"timestamp": current_time,
|
||||
"skills_so_far": skills_so_far.copy(),
|
||||
"current_skill": current_skill,
|
||||
})
|
||||
|
||||
current_time += sample_interval_seconds
|
||||
|
||||
if not timestamps_with_skills:
|
||||
console.print(f"[yellow]Warning: No valid timestamps for episode {episode_idx}[/yellow]")
|
||||
continue
|
||||
|
||||
# Store batch item
|
||||
batch_data.append({
|
||||
"episode_idx": episode_idx,
|
||||
"episode_metadata": ep,
|
||||
"video_path": extracted_path,
|
||||
"task_description": description,
|
||||
"timestamps_with_skills": timestamps_with_skills,
|
||||
"skills": skills,
|
||||
})
|
||||
|
||||
if not batch_data:
|
||||
continue
|
||||
|
||||
# BATCHED VLM CALL for all episodes in batch
|
||||
try:
|
||||
batch_results = annotate_episodes_video_batch(
|
||||
pgen=pgen,
|
||||
batch_videos=[item["video_path"] for item in batch_data],
|
||||
batch_task_descriptions=[item["task_description"] for item in batch_data],
|
||||
batch_timestamps_with_skills=[item["timestamps_with_skills"] for item in batch_data],
|
||||
)
|
||||
|
||||
# Process results for each episode in batch
|
||||
for item, results in zip(batch_data, batch_results):
|
||||
episode_idx = item["episode_idx"]
|
||||
ep = item["episode_metadata"]
|
||||
timestamps_with_skills = item["timestamps_with_skills"]
|
||||
description = item["task_description"]
|
||||
|
||||
timestamps_processed += len(results)
|
||||
|
||||
# Map results back to timestamps and create task indices
|
||||
timestamp_to_result = {}
|
||||
for result in results:
|
||||
ts = result["timestamp"]
|
||||
timestamp_to_result[ts] = result
|
||||
|
||||
# Process each sampled timestamp
|
||||
for ts_info in timestamps_with_skills:
|
||||
ts = ts_info["timestamp"]
|
||||
result = timestamp_to_result.get(ts, {
|
||||
"timestamp": ts,
|
||||
"scenario_type": "specific_object",
|
||||
"response_type": "confirmation",
|
||||
"user_prompt": "",
|
||||
"robot_utterance": "",
|
||||
})
|
||||
|
||||
# Create unique task key
|
||||
task_key = (
|
||||
result["user_prompt"],
|
||||
result["robot_utterance"],
|
||||
ts_info["current_skill"],
|
||||
result["scenario_type"],
|
||||
result["response_type"],
|
||||
)
|
||||
|
||||
# Assign or create task index
|
||||
if task_key not in high_level_tasks:
|
||||
high_level_tasks[task_key] = task_index_counter
|
||||
task_index_counter += 1
|
||||
|
||||
current_task_idx = high_level_tasks[task_key]
|
||||
|
||||
# Find all frames at this timestamp and assign task_idx
|
||||
ep_from = ep["dataset_from_index"]
|
||||
ep_to = ep["dataset_to_index"]
|
||||
|
||||
for frame_idx in range(ep_from, ep_to):
|
||||
frame = dataset[frame_idx]
|
||||
frame_ts = frame["timestamp"].item()
|
||||
|
||||
# Assign to closest sampled timestamp
|
||||
if abs(frame_ts - ts) < sample_interval_seconds / 2:
|
||||
task_indices[frame_idx] = current_task_idx
|
||||
|
||||
# Save for debugging
|
||||
debug_outputs.append({
|
||||
"episode_id": int(episode_idx),
|
||||
"timestamp": float(ts),
|
||||
"skill_current": ts_info["current_skill"],
|
||||
"skills_so_far": ts_info["skills_so_far"],
|
||||
"task_description": description,
|
||||
"video_mode": True,
|
||||
**result,
|
||||
})
|
||||
|
||||
finally:
|
||||
# Clean up extracted videos
|
||||
for extracted_path in extracted_videos:
|
||||
if extracted_path and extracted_path.exists():
|
||||
extracted_path.unlink()
|
||||
|
||||
console.print(f"[green]✓ Processed {timestamps_processed} timestamps across {len(episodes)} episodes[/green]")
|
||||
|
||||
# Create tasks DataFrame
|
||||
tasks_data = []
|
||||
for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
|
||||
user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
|
||||
tasks_data.append({
|
||||
"task": f"{user_prompt} | {robot_utterance}",
|
||||
"task_index": task_idx,
|
||||
"user_prompt": user_prompt,
|
||||
"robot_utterance": robot_utterance,
|
||||
"skill": skill,
|
||||
"scenario_type": scenario_type,
|
||||
"response_type": response_type,
|
||||
})
|
||||
|
||||
tasks_df = pd.DataFrame(tasks_data).set_index("task")
|
||||
console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")
|
||||
|
||||
return tasks_df, task_indices, debug_outputs
|
||||
|
||||
|
||||
def generate_synthetic_data(
|
||||
@@ -446,6 +1041,9 @@ def generate_synthetic_data(
|
||||
image_keys: list[str],
|
||||
sample_interval_seconds: float = 1.0,
|
||||
console: Console | None = None,
|
||||
video_mode: bool = False,
|
||||
video_key: str | None = None,
|
||||
video_batch_size: int = 1,
|
||||
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
|
||||
"""
|
||||
Generate synthetic dialogue data for entire dataset.
|
||||
@@ -458,9 +1056,12 @@ def generate_synthetic_data(
|
||||
dataset: LeRobot dataset with skill annotations
|
||||
pgen: Qwen model wrapper
|
||||
skills_metadata: Loaded skills.json metadata
|
||||
image_keys: List of image observation keys to use
|
||||
image_keys: List of image observation keys to use (for image mode)
|
||||
sample_interval_seconds: Generate dialogue every N seconds (default: 1.0)
|
||||
console: Rich console for logging
|
||||
video_mode: If True, use video clips instead of sampled images
|
||||
video_key: Video observation key for video mode (e.g., 'observation.images.base')
|
||||
video_batch_size: Number of episodes to process in each VLM batch (video mode only)
|
||||
|
||||
Returns:
|
||||
Tuple of (tasks_df, task_indices_array, debug_outputs)
|
||||
@@ -486,6 +1087,27 @@ def generate_synthetic_data(
|
||||
# For debugging - save to JSONL
|
||||
debug_outputs = []
|
||||
|
||||
# Initialize video extractor if in video mode
|
||||
video_extractor = VideoExtractor(console) if video_mode else None
|
||||
|
||||
if video_mode:
|
||||
if video_key is None:
|
||||
raise ValueError("video_key must be provided when video_mode=True")
|
||||
console.print(f"[cyan]Using VIDEO MODE with video key: {video_key}[/cyan]")
|
||||
console.print(f"[cyan]Video batch size: {video_batch_size} episodes per VLM call[/cyan]")
|
||||
# In video mode, process episodes in batches with full videos
|
||||
return _generate_synthetic_data_video_mode(
|
||||
dataset=dataset,
|
||||
pgen=pgen,
|
||||
skills_metadata=skills_metadata,
|
||||
video_key=video_key,
|
||||
video_extractor=video_extractor,
|
||||
console=console,
|
||||
sample_interval_seconds=sample_interval_seconds,
|
||||
batch_size=video_batch_size,
|
||||
)
|
||||
|
||||
# IMAGE MODE (original logic)
|
||||
# Track sampling
|
||||
last_sample_timestamp = {} # episode_idx -> last sampled timestamp
|
||||
last_task_index = {} # episode_idx -> last generated task_index
|
||||
@@ -566,7 +1188,7 @@ def generate_synthetic_data(
|
||||
continue
|
||||
|
||||
# Generate synthetic dialogue
|
||||
result = annotate_sample(
|
||||
result = annotate_sample_image(
|
||||
pgen=pgen,
|
||||
images=images,
|
||||
task_description=description,
|
||||
@@ -677,11 +1299,18 @@ def main():
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=textwrap.dedent("""\
|
||||
Examples:
|
||||
# Generate synthetic data for a dataset
|
||||
# Generate synthetic data for a dataset (image mode)
|
||||
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \\
|
||||
--output-dir ./output
|
||||
|
||||
# Use video mode with batching (passes full episode videos)
|
||||
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \\
|
||||
--video-mode \\
|
||||
--video-key observation.images.base \\
|
||||
--video-batch-size 4
|
||||
|
||||
# Use Qwen3 model with custom parameters
|
||||
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \\
|
||||
@@ -742,6 +1371,24 @@ def main():
|
||||
help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. "
|
||||
"Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video-mode",
|
||||
action="store_true",
|
||||
help="Use video input instead of sampled image frames. Passes entire skill video clips to the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Video observation key for video mode (e.g., 'observation.images.base'). "
|
||||
"If not specified, uses the first available video key.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video-batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of episodes to process in each VLM batch call in video mode (default: 1)",
|
||||
)
|
||||
|
||||
# Output options
|
||||
parser.add_argument(
|
||||
@@ -795,9 +1442,28 @@ def main():
|
||||
temperature=args.temperature,
|
||||
)
|
||||
|
||||
# Get image keys
|
||||
# Get image keys (for image mode)
|
||||
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
|
||||
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
|
||||
if not args.video_mode:
|
||||
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
|
||||
|
||||
# Determine video key for video mode
|
||||
video_key = None
|
||||
if args.video_mode:
|
||||
if args.video_key:
|
||||
# Use explicitly provided video key
|
||||
video_key = args.video_key
|
||||
if video_key not in dataset.meta.video_keys:
|
||||
console.print(f"[red]Error: Video key '{video_key}' not found in dataset.[/red]")
|
||||
console.print(f"[yellow]Available video keys: {', '.join(dataset.meta.video_keys)}[/yellow]")
|
||||
return
|
||||
elif dataset.meta.video_keys:
|
||||
# Use first available video key
|
||||
video_key = dataset.meta.video_keys[0]
|
||||
else:
|
||||
console.print("[red]Error: No video keys found in dataset. Cannot use video mode.[/red]")
|
||||
return
|
||||
console.print(f"[cyan]Using video key for video mode: {video_key}[/cyan]")
|
||||
|
||||
# Generate synthetic data
|
||||
tasks_df, task_indices, debug_outputs = generate_synthetic_data(
|
||||
@@ -807,6 +1473,9 @@ def main():
|
||||
image_keys=image_keys,
|
||||
sample_interval_seconds=args.sample_interval,
|
||||
console=console,
|
||||
video_mode=args.video_mode,
|
||||
video_key=video_key,
|
||||
video_batch_size=args.video_batch_size,
|
||||
)
|
||||
|
||||
# Save high-level tasks
|
||||
@@ -830,7 +1499,6 @@ def main():
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
breakpoint()
|
||||
new_dataset = add_features(
|
||||
dataset=dataset,
|
||||
features={
|
||||
|
||||
@@ -0,0 +1,23 @@
import torch
from huggingface_hub import HfApi

import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=32,
    shuffle=True,
)

batch = next(iter(dataloader))
print(batch.keys())
print(batch['task_index_high_level'].shape)
print(batch['task_index_high_level'])
print(batch['user_prompt'][0])
print(batch['robot_utterance'][0])
print(batch['task'][0])
breakpoint()

@@ -9,7 +9,7 @@ MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
# Alternative: MODEL="Qwen/Qwen2-VL-7B-Instruct"

OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations"
OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations1"
BATCH_SIZE=32
TEMPERATURE=0.9
SAMPLE_INTERVAL=5.0 # Generate dialogue every 5 seconds (all episodes processed)

@@ -20,6 +20,7 @@ python examples/dataset/annotate_pgen.py \
    --model "$MODEL" \
    --output-dir "$OUTPUT_DIR" \
    --temperature "$TEMPERATURE" \
    --batch-size "$BATCH_SIZE" \
    --sample-interval "$SAMPLE_INTERVAL" \
    --num-image-views-per-sample 1

@@ -29,3 +30,13 @@ python examples/dataset/annotate_pgen.py \
# To push to hub after generation:
# Add --push-to-hub flag

# Efficient batch processing: 4 episodes at once
# python examples/dataset/annotate_pgen.py \
#     --repo-id "$REPO_ID" \
#     --model "$MODEL" \
#     --output-dir "$OUTPUT_DIR" \
#     --video-mode \
#     --video-key observation.images.up \
#     --video-batch-size "$BATCH_SIZE" \
#     --sample-interval 1.0