mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
757 lines
27 KiB
Python
757 lines
27 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""
|
||
Synthetic Data Generation for Hi-Robot Style Hierarchical Policy Training.
|
||
|
||
This script generates synthetic user prompts (ℓ_t) and robot utterances (u_t) for
|
||
hierarchical policy training using Qwen VLM as the generator model (pgen).
|
||
|
||
The pipeline:
|
||
1. Loads a LeRobot dataset with skill annotations (from annotate.py)
|
||
2. For each frame, generates synthetic dialogue based on:
|
||
- Visual context (images at time t)
|
||
- Current skill being performed
|
||
- History of previous skills
|
||
- High-level task description
|
||
3. Saves results as high-level tasks and updates dataset with task_index_high_level
|
||
|
||
Usage:
|
||
```bash
|
||
python examples/dataset/annotate_pgen.py \
|
||
--repo-id lerobot/svla_so101_pickplace \
|
||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||
--output-dir /path/to/output \
|
||
--batch-size 1
|
||
```
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import re
|
||
import textwrap
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from PIL import Image
|
||
from rich.console import Console
|
||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||
from tqdm import tqdm
|
||
|
||
from lerobot.datasets.dataset_tools import add_features
|
||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||
|
||
|
||
# =============================================================================
|
||
# Prompt Template for pgen
|
||
# =============================================================================
|
||
|
||
PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
|
||
# Role
|
||
You are a robot-assistant dialogue generator for hierarchical robot policies.
|
||
|
||
# Task
|
||
You will receive:
|
||
- A list of images showing the current robot scene at time t
|
||
- The high-level task: {task_description}
|
||
- Previous skill steps completed: {skill_history}
|
||
- The next skill to be performed by the robot: {skill_current}
|
||
|
||
# Your Goal
|
||
Generate two things that create a natural human-robot interaction:
|
||
1. **user_prompt**: A natural-sounding user request that logically leads to the robot
|
||
performing the skill "{skill_current}" given the task context and history.
|
||
2. **robot_utterance**: A natural robot reply acknowledging or clarifying the request.
|
||
|
||
# Guidelines
|
||
- The user prompt should be grounded in the visual scene and task context
|
||
- Vary interaction types: direct commands, implicit requests, corrections, constraints
|
||
- Examples of user prompt styles:
|
||
* Direct: "Can you pick up the red brick?"
|
||
* Implicit: "I need something red for the tower"
|
||
* Negative: "Don't pick up the blue one"
|
||
* Constraint: "Make sure to handle it gently"
|
||
* Correction: "Actually, move to the other box instead"
|
||
- Robot responses should be appropriate: confirmations, clarifications, or error handling
|
||
- Use the skill history to ensure continuity (don't repeat past actions)
|
||
- Consider world knowledge (dietary preferences, object properties, etc.)
|
||
|
||
# Scenario Types (choose one that fits):
|
||
- **specific_object**: User specifies exact object/action
|
||
- **negative_task**: User says what NOT to do
|
||
- **situated_correction**: User adjusts based on current state
|
||
- **implicit_request**: User implies need without direct command
|
||
- **constraint_based**: User adds specific constraints
|
||
|
||
# Response Types (choose one that fits):
|
||
- **confirmation**: Simple "OK, I'll do X"
|
||
- **clarification**: "Just to confirm, you want me to..."
|
||
- **acknowledgment**: "Got it, [doing action]"
|
||
- **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"
|
||
|
||
# Output Format
|
||
Respond ONLY with valid JSON:
|
||
{{
|
||
"scenario_type": "one of the types above",
|
||
"response_type": "one of the types above",
|
||
"user_prompt": "natural user request here",
|
||
"robot_utterance": "natural robot response here"
|
||
}}
|
||
|
||
The responses must be grounded in the visual scene, the task, and the skill history.
|
||
Make it sound like a real human-robot interaction.
|
||
""")
|
||
|
||
|
||
def construct_prompt(
    task_description: str,
    skill_history: list[str],
    skill_current: str,
) -> str:
    """
    Render the pgen text prompt for one sample.

    Only the five most recent entries of ``skill_history`` are listed (each
    wrapped in double quotes and comma-separated); when older entries were
    dropped the list is prefixed with ``"... "``. An empty history is rendered
    as ``"None (starting the task)"``.

    Args:
        task_description: High-level task description
        skill_history: Skills already completed, oldest first
        skill_current: Skill the robot is about to perform

    Returns:
        ``PGEN_PROMPT_TEMPLATE`` with its three placeholders filled in
    """
    if not skill_history:
        rendered_history = "None (starting the task)"
    else:
        recent_skills = skill_history[-5:]  # keep the prompt short
        rendered_history = ", ".join([f'"{name}"' for name in recent_skills])
        if len(skill_history) > 5:
            # Signal to the model that earlier steps were truncated.
            rendered_history = "... " + rendered_history

    return PGEN_PROMPT_TEMPLATE.format(
        task_description=task_description,
        skill_history=rendered_history,
        skill_current=skill_current,
    )
|
||
|
||
|
||
# =============================================================================
|
||
# Qwen VLM Interface
|
||
# =============================================================================
|
||
|
||
class QwenPgen:
    """Qwen VLM wrapper for synthetic dialogue generation.

    The model and its processor are loaded once in ``__init__``;
    ``call_qwen`` then runs one single-turn generation per call and returns
    the parsed dialogue fields.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.bfloat16,
        temperature: float = 0.7,
    ):
        """
        Load the Qwen model and processor.

        Args:
            model_name: Hugging Face model identifier or local path
            device: Value passed as ``device_map`` and used to place the inputs
            torch_dtype: Dtype the weights are loaded in
            temperature: Sampling temperature; values <= 0 select greedy decoding
        """
        # Imported lazily so the module can be imported without these
        # optional, heavy dependencies installed.
        from qwen_vl_utils import process_vision_info
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        self.console = Console()
        self.device = device
        self.model_name = model_name
        self.temperature = temperature
        self.process_vision_info = process_vision_info

        self.console.print(f"[cyan]Loading Qwen model: {model_name}...[/cyan]")

        # NOTE: trust_remote_code=True lets the checkpoint repository run its
        # own Python code at load time; only point this at trusted models.
        # NOTE(review): every name containing "qwen3" is loaded with the MoE
        # class, everything else with the Qwen2-VL class — confirm this is
        # right for dense Qwen3-VL and for Qwen2.5-VL checkpoints.
        if "qwen3" in model_name.lower():
            from transformers import Qwen3VLMoeForConditionalGeneration

            model_cls = Qwen3VLMoeForConditionalGeneration
        else:
            model_cls = Qwen2VLForConditionalGeneration

        self.model = model_cls.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")

    def call_qwen(
        self,
        images: list[Image.Image | str],
        prompt: str,
    ) -> dict[str, str]:
        """
        Call Qwen VLM to generate synthetic dialogue.

        Args:
            images: List of PIL Images or image paths
            prompt: Text prompt for generation

        Returns:
            Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance

        Raises:
            ValueError: If no JSON object can be extracted from the model reply.
        """
        # One user turn: all images first, then the text prompt. PIL images
        # and path strings are passed through unchanged.
        content = [{"type": "image", "image": img} for img in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = self.process_vision_info(messages)

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Sampling requires a strictly positive temperature; fall back to
        # greedy decoding instead of failing when temperature <= 0.
        generation_kwargs: dict[str, Any] = {"max_new_tokens": 512}
        if self.temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = self.temperature
        else:
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, **generation_kwargs)

        # Drop the prompt tokens so only the newly generated text is decoded.
        new_tokens = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
        response = self.processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()

        return self._parse_response(response)

    def _parse_response(self, response: str) -> dict[str, str]:
        """
        Parse the JSON object contained in a model reply.

        The reply may be bare JSON, JSON inside a Markdown code fence, or JSON
        surrounded by extra text. Missing (or null) fields fall back to
        defaults; non-string values are converted with ``str``.

        Raises:
            ValueError: If no JSON object can be decoded from the reply.
        """
        defaults = {
            "scenario_type": "specific_object",
            "response_type": "confirmation",
            "user_prompt": "",
            "robot_utterance": "",
        }

        # Strip a Markdown fence if the model wrapped its answer in one.
        payload = response
        if "```json" in payload:
            payload = payload.split("```json")[1].split("```")[0]
        elif "```" in payload:
            payload = payload.split("```")[1].split("```")[0]

        # Try the whole payload first, then the outermost {...} span in it.
        candidates = [payload]
        embedded = re.search(r"\{.*\}", payload, re.DOTALL)
        if embedded:
            candidates.append(embedded.group())

        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object (e.g. a list): keep looking.
                continue
            parsed = {}
            for key, default in defaults.items():
                value = data.get(key)
                parsed[key] = default if value is None else str(value)
            return parsed

        raise ValueError(f"Could not parse response: {payload[:200]}...")
|
||
|
||
|
||
# =============================================================================
|
||
# Annotation Pipeline
|
||
# =============================================================================
|
||
|
||
def load_skills_metadata(dataset_root: Path) -> dict | None:
    """
    Load ``meta/skills.json`` from an annotated dataset.

    Args:
        dataset_root: Root directory of the dataset (``Path`` or path string)

    Returns:
        The decoded JSON content, or ``None`` when the file does not exist.
    """
    skills_path = Path(dataset_root) / "meta" / "skills.json"
    if not skills_path.exists():
        return None

    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(skills_path, encoding="utf-8") as f:
        return json.load(f)
|
||
|
||
|
||
def get_skill_at_timestamp(skills: list[dict], timestamp: float) -> str | None:
    """
    Find which skill covers a given timestamp.

    Each skill is a dict with ``"name"``, ``"start"`` and ``"end"`` keys and
    covers the half-open interval ``[start, end)``.

    When no interval contains the timestamp, the closest sensible skill is
    used instead of always defaulting to the final one:
      - in a gap between segments, or after the last segment: the skill that
        started most recently before the timestamp;
      - before every segment: the first skill in the list.

    Args:
        skills: Skill segments of one episode
        timestamp: Time to look up, in the same unit as ``start``/``end``

    Returns:
        The skill name, or ``None`` when ``skills`` is empty.
    """
    if not skills:
        return None

    for skill in skills:
        if skill["start"] <= timestamp < skill["end"]:
            return skill["name"]

    # No exact hit: fall back to the latest segment that has already started.
    started = [skill for skill in skills if skill["start"] <= timestamp]
    if started:
        return max(started, key=lambda skill: skill["start"])["name"]

    # Timestamp precedes every segment.
    return skills[0]["name"]
|
||
|
||
|
||
def annotate_sample(
    pgen: QwenPgen,
    images: list[Image.Image | str],
    task_description: str,
    skill_history: list[str],
    skill_current: str,
) -> dict[str, str]:
    """
    Produce the synthetic dialogue for one sample.

    Builds the text prompt with ``construct_prompt`` and forwards it, together
    with the images, to ``pgen.call_qwen``.

    Args:
        pgen: Qwen model wrapper
        images: Images observed at the current timestep
        task_description: High-level task description
        skill_history: Skills completed so far
        skill_current: Skill being performed now

    Returns:
        Whatever ``pgen.call_qwen`` returns for this prompt
    """
    text_prompt = construct_prompt(
        task_description,
        skill_history,
        skill_current,
    )
    return pgen.call_qwen(images, text_prompt)
|
||
|
||
|
||
def _frame_to_pil_images(frame: dict, image_keys: list[str]) -> list[Image.Image]:
    """Convert the requested camera tensors of one dataset frame to PIL images."""
    images = []
    for img_key in image_keys:
        if img_key not in frame:
            continue

        # Assumes tensors are (C, H, W) floats in [0, 1], as the scaling by
        # 255 below implies — TODO confirm against the dataset loader.
        img_tensor = frame[img_key]
        if len(img_tensor.shape) == 4:  # (T, C, H, W): keep the last frame
            img_tensor = img_tensor[-1]

        img_array = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        images.append(Image.fromarray(img_array))
    return images


def generate_synthetic_data(
    dataset: LeRobotDataset,
    pgen: QwenPgen,
    skills_metadata: dict,
    image_keys: list[str],
    sample_interval_seconds: float = 1.0,
    console: Console | None = None,
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
    """
    Generate synthetic dialogue data for entire dataset.

    Every frame of the dataset is visited, but the VLM is only called for the
    first frame of each episode and then whenever at least
    ``sample_interval_seconds`` have elapsed since the last sampled frame of
    that episode. All other frames — including frames whose generation fails
    or that have no skill or image — inherit the task index of the most
    recent successful sample in the same episode. Frames with no earlier
    successful sample in their episode keep task index 0.

    Errors on individual frames are logged and skipped so that one bad frame
    does not abort the whole run.

    Args:
        dataset: LeRobot dataset with skill annotations
        pgen: Qwen model wrapper
        skills_metadata: Loaded skills.json metadata
        image_keys: List of image observation keys to use
        sample_interval_seconds: Generate dialogue every N seconds (default: 1.0)
        console: Rich console for logging

    Returns:
        Tuple of (tasks_df, task_indices_array, debug_outputs)
        - tasks_df: DataFrame indexed by "task" with one row per unique
          high-level task (empty, but with all columns, if nothing was generated)
        - task_indices_array: Array of task indices for each frame (full dataset length)
        - debug_outputs: List of debug dictionaries (only for sampled frames)
    """
    if console is None:
        console = Console()

    coarse_description = skills_metadata.get("coarse_description", "Complete the task")
    episodes = skills_metadata.get("episodes", {})

    # (user_prompt, robot_utterance, skill, scenario_type, response_type) -> task_index
    high_level_tasks = {}

    # One entry per frame - MUST match the full dataset length.
    full_dataset_length = len(dataset)
    task_indices = np.zeros(full_dataset_length, dtype=np.int64)

    debug_outputs = []

    last_sample_timestamp = {}  # episode_idx -> last sampled timestamp
    last_task_index = {}  # episode_idx -> last generated task_index
    warned_episodes = set()  # episodes already reported as missing
    frames_sampled = 0

    console.print(f"[cyan]Processing all {full_dataset_length} frames from {dataset.meta.total_episodes} episodes...[/cyan]")
    console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s (fps: {dataset.meta.fps})[/cyan]")

    for frame_idx in tqdm(range(full_dataset_length), desc="Generating synthetic dialogue"):
        try:
            frame = dataset[frame_idx]
            episode_idx = frame["episode_index"].item()
            timestamp = frame["timestamp"].item()

            # Default to the episode's latest dialogue so a frame that is
            # skipped or fails below stays consistent with its neighbours
            # instead of silently pointing at task 0.
            if episode_idx in last_task_index:
                task_indices[frame_idx] = last_task_index[episode_idx]

            episode_key = str(episode_idx)
            if episode_key not in episodes:
                # Warn once per episode rather than once per frame.
                if episode_idx not in warned_episodes:
                    warned_episodes.add(episode_idx)
                    console.print(f"[yellow]Warning: Episode {episode_idx} not in skills metadata[/yellow]")
                continue

            episode_data = episodes[episode_key]
            skills = episode_data.get("skills", [])
            description = episode_data.get("description", coarse_description)

            current_skill = get_skill_at_timestamp(skills, timestamp)
            if current_skill is None:
                console.print(f"[yellow]Warning: No skill found for timestamp {timestamp}[/yellow]")
                continue

            # Sample the first frame of an episode, then every interval.
            if episode_idx not in last_sample_timestamp:
                should_sample = True
            else:
                time_since_last = timestamp - last_sample_timestamp[episode_idx]
                should_sample = time_since_last >= sample_interval_seconds

            if not should_sample:
                continue  # keeps the inherited task index set above

            last_sample_timestamp[episode_idx] = timestamp
            frames_sampled += 1

            # Skills that have fully finished before this timestamp.
            skill_history = [skill["name"] for skill in skills if skill["end"] <= timestamp]

            images = _frame_to_pil_images(frame, image_keys)
            if not images:
                console.print(f"[yellow]Warning: No images found for frame {frame_idx}[/yellow]")
                continue

            result = annotate_sample(
                pgen=pgen,
                images=images,
                task_description=description,
                skill_history=skill_history,
                skill_current=current_skill,
            )

            task_key = (
                result["user_prompt"],
                result["robot_utterance"],
                current_skill,
                result["scenario_type"],
                result["response_type"],
            )

            # New keys get the next free index; known keys reuse theirs.
            current_task_idx = high_level_tasks.setdefault(task_key, len(high_level_tasks))
            task_indices[frame_idx] = current_task_idx
            last_task_index[episode_idx] = current_task_idx

            debug_outputs.append({
                "episode_id": int(episode_idx),
                "frame_index": frame_idx,
                "timestamp": float(timestamp),
                "skill_current": current_skill,
                "skill_history": skill_history,
                "task_description": description,
                "sampled": True,
                **result,
            })

        except Exception as e:
            # Deliberate best-effort: report the frame and keep going.
            console.print(f"[red]Error processing frame {frame_idx}: {e}[/red]")
            continue

    # Guard against an empty dataset when computing the percentage.
    sampled_pct = frames_sampled / full_dataset_length * 100 if full_dataset_length else 0.0
    console.print(f"[green]✓ Sampled {frames_sampled} frames out of {full_dataset_length} total ({sampled_pct:.1f}%)[/green]")

    tasks_data = []
    for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
        user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
        tasks_data.append({
            "task": f"{user_prompt} | {robot_utterance}",
            "task_index": task_idx,
            "user_prompt": user_prompt,
            "robot_utterance": robot_utterance,
            "skill": skill,
            "scenario_type": scenario_type,
            "response_type": response_type,
        })

    # Explicit columns keep set_index("task") valid even when no task exists.
    task_columns = [
        "task",
        "task_index",
        "user_prompt",
        "robot_utterance",
        "skill",
        "scenario_type",
        "response_type",
    ]
    tasks_df = pd.DataFrame(tasks_data, columns=task_columns).set_index("task")

    console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")

    return tasks_df, task_indices, debug_outputs
|
||
|
||
|
||
def save_high_level_tasks(
    tasks_df: pd.DataFrame,
    dataset_root: Path,
    console: Console | None = None,
) -> None:
    """
    Write the high-level task table to ``meta/tasks_high_level.parquet``.

    The ``meta`` directory is created under ``dataset_root`` if needed, and a
    confirmation line is printed to ``console`` (a new ``Console`` when none
    is given).
    """
    log = console if console is not None else Console()

    meta_dir = dataset_root / "meta"
    meta_dir.mkdir(parents=True, exist_ok=True)
    output_path = meta_dir / "tasks_high_level.parquet"

    tasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
    log.print(f"[green]✓ Saved high-level tasks to {output_path}[/green]")
|
||
|
||
|
||
def save_debug_outputs(
    debug_outputs: list[dict],
    dataset_root: Path,
    console: Console | None = None,
) -> None:
    """
    Dump the per-sample debug records to ``meta/syn_annotations.jsonl``.

    Each record becomes one JSON document on its own line. The ``meta``
    directory is created under ``dataset_root`` if needed, and a confirmation
    line is printed to ``console`` (a new ``Console`` when none is given).
    """
    log = console if console is not None else Console()

    meta_dir = dataset_root / "meta"
    meta_dir.mkdir(parents=True, exist_ok=True)
    output_path = meta_dir / "syn_annotations.jsonl"

    with open(output_path, "w") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in debug_outputs)

    log.print(f"[green]✓ Saved debug annotations to {output_path}[/green]")
|
||
|
||
|
||
# =============================================================================
|
||
# Main Entry Point
|
||
# =============================================================================
|
||
|
||
def main():
    """
    Command-line entry point for synthetic data generation.

    Steps:
      1. Parse the arguments and load the dataset (local directory or Hub repo).
      2. Load ``meta/skills.json``; print an error and return if it is missing.
      3. Load the Qwen model and generate dialogue for the dataset.
      4. Save the task table and the debug JSONL under the dataset's ``meta``.
      5. Write a new dataset with the ``task_index_high_level`` feature added.
      6. Optionally push the new dataset to the Hub (only with ``--repo-id``).
    """
    parser = argparse.ArgumentParser(
        description="Generate synthetic dialogue data for hierarchical robot policies",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""\
            Examples:
            # Generate synthetic data for a dataset
            python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
            --model Qwen/Qwen2-VL-7B-Instruct \\
            --output-dir ./output

            # Use Qwen3 model with custom parameters
            python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
            --model Qwen/Qwen3-VL-30B-A3B-Instruct \\
            --temperature 0.8 \\
            --batch-size 1
            """),
    )

    # Data source: exactly one of --data-dir / --repo-id is required.
    data_group = parser.add_mutually_exclusive_group(required=True)
    data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
    data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")

    # Model configuration
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2-VL-7B-Instruct",
        help="VLM model to use (default: Qwen/Qwen2-VL-7B-Instruct)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run model on (default: cuda)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype (default: bfloat16)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)",
    )

    # Processing options
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for processing (default: 1) [currently unused]",
    )
    parser.add_argument(
        "--num-image-views-per-sample",
        type=int,
        default=1,
        help="Number of camera views to use per sample (default: 1)",
    )
    parser.add_argument(
        "--sample-interval",
        type=float,
        default=1.0,
        help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. "
        "Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.",
    )

    # Output options
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for modified dataset",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push modified dataset to HuggingFace Hub",
    )

    args = parser.parse_args()
    console = Console()

    # Load dataset
    console.print("[cyan]Loading dataset...[/cyan]")
    if args.data_dir:
        dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir)
        dataset_root = Path(args.data_dir)
    else:
        dataset = LeRobotDataset(repo_id=args.repo_id)
        dataset_root = dataset.root

    console.print(f"[green]✓ Loaded dataset with {len(dataset)} frames[/green]")

    # Load skills metadata; check it before paying for the model load.
    console.print("[cyan]Loading skills metadata...[/cyan]")
    skills_metadata = load_skills_metadata(dataset_root)
    if skills_metadata is None:
        console.print("[red]Error: No skills.json found. Run annotate.py first![/red]")
        return

    console.print(f"[green]✓ Loaded skills for {len(skills_metadata.get('episodes', {}))} episodes[/green]")

    # Initialize model
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    console.print(f"[cyan]Initializing {args.model}...[/cyan]")
    pgen = QwenPgen(
        model_name=args.model,
        device=args.device,
        torch_dtype=torch_dtype,
        temperature=args.temperature,
    )

    # Use the first N camera streams of the dataset.
    image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
    console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")

    # Generate synthetic data
    tasks_df, task_indices, debug_outputs = generate_synthetic_data(
        dataset=dataset,
        pgen=pgen,
        skills_metadata=skills_metadata,
        image_keys=image_keys,
        sample_interval_seconds=args.sample_interval,
        console=console,
    )

    # Save high-level tasks and the per-sample debug log
    save_high_level_tasks(tasks_df, dataset_root, console)
    save_debug_outputs(debug_outputs, dataset_root, console)

    # Add task_index_high_level feature to dataset
    console.print("[cyan]Adding task_index_high_level feature to dataset...[/cyan]")

    # The new repo id is the same whether or not --output-dir is given;
    # output_dir=None is passed through to add_features unchanged.
    output_dir = Path(args.output_dir) if args.output_dir else None
    repo_id = f"{dataset.repo_id}_with_high_level_tasks"

    feature_info = {
        "dtype": "int64",
        "shape": (1,),
        "names": None,
    }
    new_dataset = add_features(
        dataset=dataset,
        features={
            "task_index_high_level": (task_indices, feature_info),
        },
        output_dir=output_dir,
        repo_id=repo_id,
    )

    console.print("[bold green]✓ Successfully added task_index_high_level feature![/bold green]")
    console.print(f"  New dataset saved to: {new_dataset.root}")
    console.print(f"  Total high-level tasks: {len(tasks_df)}")

    # Push to hub if requested
    if args.push_to_hub:
        if args.data_dir:
            console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]")
        else:
            console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
            try:
                new_dataset.push_to_hub(push_videos=False)
                console.print(f"[green]✓ Pushed to {repo_id}[/green]")
            except Exception as e:
                # A failed upload must not hide that the local dataset was written.
                console.print(f"[red]Push failed: {e}[/red]")
|
||
|
||
|
||
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||
|