mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
add step2
This commit is contained in:
@@ -0,0 +1,756 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Synthetic Data Generation for Hi-Robot Style Hierarchical Policy Training.
|
||||
|
||||
This script generates synthetic user prompts (ℓ_t) and robot utterances (u_t) for
|
||||
hierarchical policy training using Qwen VLM as the generator model (pgen).
|
||||
|
||||
The pipeline:
|
||||
1. Loads a LeRobot dataset with skill annotations (from annotate.py)
|
||||
2. For each frame, generates synthetic dialogue based on:
|
||||
- Visual context (images at time t)
|
||||
- Current skill being performed
|
||||
- History of previous skills
|
||||
- High-level task description
|
||||
3. Saves results as high-level tasks and updates dataset with task_index_high_level
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python examples/dataset/annotate_pgen.py \
|
||||
--repo-id lerobot/svla_so101_pickplace \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--output-dir /path/to/output \
|
||||
--batch-size 1
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from PIL import Image
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.dataset_tools import add_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Prompt Template for pgen
|
||||
# =============================================================================
|
||||
|
||||
PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
|
||||
# Role
|
||||
You are a robot-assistant dialogue generator for hierarchical robot policies.
|
||||
|
||||
# Task
|
||||
You will receive:
|
||||
- A list of images showing the current robot scene at time t
|
||||
- The high-level task: {task_description}
|
||||
- Previous skill steps completed: {skill_history}
|
||||
- The next skill to be performed by the robot: {skill_current}
|
||||
|
||||
# Your Goal
|
||||
Generate two things that create a natural human-robot interaction:
|
||||
1. **user_prompt**: A natural-sounding user request that logically leads to the robot
|
||||
performing the skill "{skill_current}" given the task context and history.
|
||||
2. **robot_utterance**: A natural robot reply acknowledging or clarifying the request.
|
||||
|
||||
# Guidelines
|
||||
- The user prompt should be grounded in the visual scene and task context
|
||||
- Vary interaction types: direct commands, implicit requests, corrections, constraints
|
||||
- Examples of user prompt styles:
|
||||
* Direct: "Can you pick up the red brick?"
|
||||
* Implicit: "I need something red for the tower"
|
||||
* Negative: "Don't pick up the blue one"
|
||||
* Constraint: "Make sure to handle it gently"
|
||||
* Correction: "Actually, move to the other box instead"
|
||||
- Robot responses should be appropriate: confirmations, clarifications, or error handling
|
||||
- Use the skill history to ensure continuity (don't repeat past actions)
|
||||
- Consider world knowledge (dietary preferences, object properties, etc.)
|
||||
|
||||
# Scenario Types (choose one that fits):
|
||||
- **specific_object**: User specifies exact object/action
|
||||
- **negative_task**: User says what NOT to do
|
||||
- **situated_correction**: User adjusts based on current state
|
||||
- **implicit_request**: User implies need without direct command
|
||||
- **constraint_based**: User adds specific constraints
|
||||
|
||||
# Response Types (choose one that fits):
|
||||
- **confirmation**: Simple "OK, I'll do X"
|
||||
- **clarification**: "Just to confirm, you want me to..."
|
||||
- **acknowledgment**: "Got it, [doing action]"
|
||||
- **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"
|
||||
|
||||
# Output Format
|
||||
Respond ONLY with valid JSON:
|
||||
{{
|
||||
"scenario_type": "one of the types above",
|
||||
"response_type": "one of the types above",
|
||||
"user_prompt": "natural user request here",
|
||||
"robot_utterance": "natural robot response here"
|
||||
}}
|
||||
|
||||
The responses must be grounded in the visual scene, the task, and the skill history.
|
||||
Make it sound like a real human-robot interaction.
|
||||
""")
|
||||
|
||||
|
||||
def construct_prompt(
|
||||
task_description: str,
|
||||
skill_history: list[str],
|
||||
skill_current: str,
|
||||
) -> str:
|
||||
"""
|
||||
Construct the text prompt for pgen.
|
||||
|
||||
Args:
|
||||
task_description: High-level task description
|
||||
skill_history: List of previously completed skills
|
||||
skill_current: Current skill to be performed
|
||||
|
||||
Returns:
|
||||
Formatted prompt string
|
||||
"""
|
||||
# Format skill history nicely
|
||||
if skill_history:
|
||||
history_str = ", ".join(f'"{s}"' for s in skill_history[-5:]) # Last 5 for context
|
||||
if len(skill_history) > 5:
|
||||
history_str = f"... {history_str}"
|
||||
else:
|
||||
history_str = "None (starting the task)"
|
||||
|
||||
return PGEN_PROMPT_TEMPLATE.format(
|
||||
task_description=task_description,
|
||||
skill_history=history_str,
|
||||
skill_current=skill_current,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Qwen VLM Interface
|
||||
# =============================================================================
|
||||
|
||||
class QwenPgen:
|
||||
"""Qwen VLM wrapper for synthetic dialogue generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
device: str = "cuda",
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
temperature: float = 0.7,
|
||||
):
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
self.console = Console()
|
||||
self.device = device
|
||||
self.model_name = model_name
|
||||
self.temperature = temperature
|
||||
self.process_vision_info = process_vision_info
|
||||
|
||||
self.console.print(f"[cyan]Loading Qwen model: {model_name}...[/cyan]")
|
||||
|
||||
# Load model based on name
|
||||
if "qwen3" in model_name.lower():
|
||||
from transformers import Qwen3VLMoeForConditionalGeneration
|
||||
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
else:
|
||||
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||||
|
||||
def call_qwen(
|
||||
self,
|
||||
images: list[Image.Image | str],
|
||||
prompt: str,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Call Qwen VLM to generate synthetic dialogue.
|
||||
|
||||
Args:
|
||||
images: List of PIL Images or image paths
|
||||
prompt: Text prompt for generation
|
||||
|
||||
Returns:
|
||||
Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
|
||||
"""
|
||||
# Build messages with images and text
|
||||
content = []
|
||||
for img in images:
|
||||
if isinstance(img, str):
|
||||
content.append({"type": "image", "image": img})
|
||||
else:
|
||||
# PIL Image - need to save temporarily or convert
|
||||
content.append({"type": "image", "image": img})
|
||||
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
]
|
||||
|
||||
# Process inputs
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=512,
|
||||
do_sample=True,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
|
||||
# Decode response
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
skip_special_tokens=True,
|
||||
)[0].strip()
|
||||
|
||||
return self._parse_response(response)
|
||||
|
||||
def _parse_response(self, response: str) -> dict[str, str]:
|
||||
"""Parse JSON response from model."""
|
||||
# Extract JSON from response
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
try:
|
||||
data = json.loads(response)
|
||||
return {
|
||||
"scenario_type": data.get("scenario_type", "specific_object"),
|
||||
"response_type": data.get("response_type", "confirmation"),
|
||||
"user_prompt": data.get("user_prompt", ""),
|
||||
"robot_utterance": data.get("robot_utterance", ""),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON object in response
|
||||
match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if match:
|
||||
data = json.loads(match.group())
|
||||
return {
|
||||
"scenario_type": data.get("scenario_type", "specific_object"),
|
||||
"response_type": data.get("response_type", "confirmation"),
|
||||
"user_prompt": data.get("user_prompt", ""),
|
||||
"robot_utterance": data.get("robot_utterance", ""),
|
||||
}
|
||||
|
||||
raise ValueError(f"Could not parse response: {response[:200]}...")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Annotation Pipeline
|
||||
# =============================================================================
|
||||
|
||||
def load_skills_metadata(dataset_root: Path) -> dict | None:
|
||||
"""Load skills.json metadata from annotated dataset."""
|
||||
skills_path = dataset_root / "meta" / "skills.json"
|
||||
if skills_path.exists():
|
||||
with open(skills_path) as f:
|
||||
return json.load(f)
|
||||
return None
|
||||
|
||||
|
||||
def get_skill_at_timestamp(skills: list[dict], timestamp: float) -> str | None:
|
||||
"""Find which skill covers a given timestamp."""
|
||||
for skill in skills:
|
||||
if skill["start"] <= timestamp < skill["end"]:
|
||||
return skill["name"]
|
||||
# Handle last frame
|
||||
if timestamp >= skill["end"] and skill == skills[-1]:
|
||||
return skill["name"]
|
||||
return skills[-1]["name"] if skills else None
|
||||
|
||||
|
||||
def annotate_sample(
|
||||
pgen: QwenPgen,
|
||||
images: list[Image.Image | str],
|
||||
task_description: str,
|
||||
skill_history: list[str],
|
||||
skill_current: str,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Generate synthetic dialogue for a single sample.
|
||||
|
||||
Args:
|
||||
pgen: Qwen model wrapper
|
||||
images: List of images at current timestep
|
||||
task_description: High-level task description
|
||||
skill_history: Previous skills completed
|
||||
skill_current: Current skill being performed
|
||||
|
||||
Returns:
|
||||
Dictionary with generated dialogue
|
||||
"""
|
||||
prompt = construct_prompt(task_description, skill_history, skill_current)
|
||||
result = pgen.call_qwen(images, prompt)
|
||||
return result
|
||||
|
||||
|
||||
def generate_synthetic_data(
|
||||
dataset: LeRobotDataset,
|
||||
pgen: QwenPgen,
|
||||
skills_metadata: dict,
|
||||
image_keys: list[str],
|
||||
sample_interval_seconds: float = 1.0,
|
||||
console: Console | None = None,
|
||||
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
|
||||
"""
|
||||
Generate synthetic dialogue data for entire dataset.
|
||||
|
||||
This function processes ALL frames in the dataset, but only calls the VLM
|
||||
at specified intervals (sample_interval_seconds). Frames between samples
|
||||
inherit the task_index from the most recent sample.
|
||||
|
||||
Args:
|
||||
dataset: LeRobot dataset with skill annotations
|
||||
pgen: Qwen model wrapper
|
||||
skills_metadata: Loaded skills.json metadata
|
||||
image_keys: List of image observation keys to use
|
||||
sample_interval_seconds: Generate dialogue every N seconds (default: 1.0)
|
||||
console: Rich console for logging
|
||||
|
||||
Returns:
|
||||
Tuple of (tasks_df, task_indices_array, debug_outputs)
|
||||
- tasks_df: DataFrame with high-level tasks (user_prompt, robot_utterance, etc.)
|
||||
- task_indices_array: Array of task indices for each frame (full dataset length)
|
||||
- debug_outputs: List of debug dictionaries (only for sampled frames)
|
||||
"""
|
||||
if console is None:
|
||||
console = Console()
|
||||
|
||||
# Extract metadata
|
||||
coarse_description = skills_metadata.get("coarse_description", "Complete the task")
|
||||
episodes = skills_metadata.get("episodes", {})
|
||||
|
||||
# Track unique high-level tasks
|
||||
high_level_tasks = {} # (user_prompt, robot_utterance, skill) -> task_index
|
||||
task_index_counter = 0 # Start at 0
|
||||
|
||||
# Array to store task index for each frame - MUST match full dataset length
|
||||
full_dataset_length = len(dataset)
|
||||
task_indices = np.zeros(full_dataset_length, dtype=np.int64)
|
||||
|
||||
# For debugging - save to JSONL
|
||||
debug_outputs = []
|
||||
|
||||
# Track sampling
|
||||
last_sample_timestamp = {} # episode_idx -> last sampled timestamp
|
||||
last_task_index = {} # episode_idx -> last generated task_index
|
||||
frames_sampled = 0
|
||||
|
||||
console.print(f"[cyan]Processing all {full_dataset_length} frames from {dataset.meta.total_episodes} episodes...[/cyan]")
|
||||
console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s (fps: {dataset.meta.fps})[/cyan]")
|
||||
|
||||
# Process each frame in the FULL dataset
|
||||
for frame_idx in tqdm(range(full_dataset_length), desc="Generating synthetic dialogue"):
|
||||
try:
|
||||
# Get frame data
|
||||
frame = dataset[frame_idx]
|
||||
episode_idx = frame["episode_index"].item()
|
||||
timestamp = frame["timestamp"].item()
|
||||
|
||||
# Get episode skills
|
||||
episode_key = str(episode_idx)
|
||||
if episode_key not in episodes:
|
||||
console.print(f"[yellow]Warning: Episode {episode_idx} not in skills metadata[/yellow]")
|
||||
continue
|
||||
|
||||
episode_data = episodes[episode_key]
|
||||
skills = episode_data.get("skills", [])
|
||||
description = episode_data.get("description", coarse_description)
|
||||
|
||||
# Find current skill
|
||||
current_skill = get_skill_at_timestamp(skills, timestamp)
|
||||
if current_skill is None:
|
||||
console.print(f"[yellow]Warning: No skill found for timestamp {timestamp}[/yellow]")
|
||||
continue
|
||||
|
||||
# Determine if we should sample this frame
|
||||
should_sample = False
|
||||
|
||||
# Always sample first frame of an episode
|
||||
if episode_idx not in last_sample_timestamp:
|
||||
should_sample = True
|
||||
last_sample_timestamp[episode_idx] = timestamp
|
||||
else:
|
||||
# Sample if enough time has passed
|
||||
time_since_last = timestamp - last_sample_timestamp[episode_idx]
|
||||
if time_since_last >= sample_interval_seconds:
|
||||
should_sample = True
|
||||
last_sample_timestamp[episode_idx] = timestamp
|
||||
|
||||
# If not sampling, reuse last task index for this episode
|
||||
if not should_sample:
|
||||
if episode_idx in last_task_index:
|
||||
task_indices[frame_idx] = last_task_index[episode_idx]
|
||||
continue
|
||||
|
||||
# Sample this frame - generate synthetic dialogue
|
||||
frames_sampled += 1
|
||||
|
||||
# Build skill history (all skills before current timestamp)
|
||||
skill_history = []
|
||||
for skill in skills:
|
||||
if skill["end"] <= timestamp:
|
||||
skill_history.append(skill["name"])
|
||||
|
||||
# Load images
|
||||
images = []
|
||||
for img_key in image_keys:
|
||||
if img_key in frame:
|
||||
# Frame images are tensors (C, H, W) in [0, 1]
|
||||
img_tensor = frame[img_key]
|
||||
if len(img_tensor.shape) == 4: # (T, C, H, W)
|
||||
img_tensor = img_tensor[-1] # Take last frame
|
||||
|
||||
# Convert to PIL Image
|
||||
img_array = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||
img_pil = Image.fromarray(img_array)
|
||||
images.append(img_pil)
|
||||
|
||||
if not images:
|
||||
console.print(f"[yellow]Warning: No images found for frame {frame_idx}[/yellow]")
|
||||
continue
|
||||
|
||||
# Generate synthetic dialogue
|
||||
result = annotate_sample(
|
||||
pgen=pgen,
|
||||
images=images,
|
||||
task_description=description,
|
||||
skill_history=skill_history,
|
||||
skill_current=current_skill,
|
||||
)
|
||||
|
||||
# Create unique task key
|
||||
task_key = (
|
||||
result["user_prompt"],
|
||||
result["robot_utterance"],
|
||||
current_skill,
|
||||
result["scenario_type"],
|
||||
result["response_type"],
|
||||
)
|
||||
|
||||
# Assign or create task index
|
||||
if task_key not in high_level_tasks:
|
||||
high_level_tasks[task_key] = task_index_counter
|
||||
task_index_counter += 1
|
||||
|
||||
current_task_idx = high_level_tasks[task_key]
|
||||
task_indices[frame_idx] = current_task_idx
|
||||
last_task_index[episode_idx] = current_task_idx
|
||||
|
||||
# Save for debugging
|
||||
debug_outputs.append({
|
||||
"episode_id": int(episode_idx),
|
||||
"frame_index": frame_idx,
|
||||
"timestamp": float(timestamp),
|
||||
"skill_current": current_skill,
|
||||
"skill_history": skill_history,
|
||||
"task_description": description,
|
||||
"sampled": True,
|
||||
**result,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error processing frame {frame_idx}: {e}[/red]")
|
||||
continue
|
||||
|
||||
console.print(f"[green]✓ Sampled {frames_sampled} frames out of {full_dataset_length} total ({frames_sampled/full_dataset_length*100:.1f}%)[/green]")
|
||||
|
||||
# Create tasks DataFrame
|
||||
tasks_data = []
|
||||
for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
|
||||
user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
|
||||
tasks_data.append({
|
||||
"task": f"{user_prompt} | {robot_utterance}",
|
||||
"task_index": task_idx,
|
||||
"user_prompt": user_prompt,
|
||||
"robot_utterance": robot_utterance,
|
||||
"skill": skill,
|
||||
"scenario_type": scenario_type,
|
||||
"response_type": response_type,
|
||||
})
|
||||
|
||||
tasks_df = pd.DataFrame(tasks_data).set_index("task")
|
||||
|
||||
console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")
|
||||
|
||||
return tasks_df, task_indices, debug_outputs
|
||||
|
||||
|
||||
def save_high_level_tasks(
|
||||
tasks_df: pd.DataFrame,
|
||||
dataset_root: Path,
|
||||
console: Console | None = None,
|
||||
) -> None:
|
||||
"""Save high-level tasks to tasks_high_level.parquet."""
|
||||
if console is None:
|
||||
console = Console()
|
||||
|
||||
output_path = dataset_root / "meta" / "tasks_high_level.parquet"
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
|
||||
console.print(f"[green]✓ Saved high-level tasks to {output_path}[/green]")
|
||||
|
||||
|
||||
def save_debug_outputs(
|
||||
debug_outputs: list[dict],
|
||||
dataset_root: Path,
|
||||
console: Console | None = None,
|
||||
) -> None:
|
||||
"""Save debug outputs to JSONL file."""
|
||||
if console is None:
|
||||
console = Console()
|
||||
|
||||
output_path = dataset_root / "meta" / "syn_annotations.jsonl"
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, "w") as f:
|
||||
for item in debug_outputs:
|
||||
f.write(json.dumps(item) + "\n")
|
||||
|
||||
console.print(f"[green]✓ Saved debug annotations to {output_path}[/green]")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Entry Point
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
"""Main entry point for synthetic data generation."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate synthetic dialogue data for hierarchical robot policies",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=textwrap.dedent("""\
|
||||
Examples:
|
||||
# Generate synthetic data for a dataset
|
||||
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \\
|
||||
--output-dir ./output
|
||||
|
||||
# Use Qwen3 model with custom parameters
|
||||
python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \\
|
||||
--temperature 0.8 \\
|
||||
--batch-size 1
|
||||
"""),
|
||||
)
|
||||
|
||||
# Data source
|
||||
data_group = parser.add_mutually_exclusive_group(required=True)
|
||||
data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
|
||||
data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")
|
||||
|
||||
# Model configuration
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen2-VL-7B-Instruct",
|
||||
help="VLM model to use (default: Qwen/Qwen2-VL-7B-Instruct)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="Device to run model on (default: cuda)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
choices=["bfloat16", "float16", "float32"],
|
||||
help="Model dtype (default: bfloat16)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Sampling temperature (default: 0.7)",
|
||||
)
|
||||
|
||||
# Processing options
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size for processing (default: 1) [currently unused]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-image-views-per-sample",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of camera views to use per sample (default: 1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-interval",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. "
|
||||
"Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.",
|
||||
)
|
||||
|
||||
# Output options
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output directory for modified dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push modified dataset to HuggingFace Hub",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
console = Console()
|
||||
|
||||
# Load dataset
|
||||
console.print("[cyan]Loading dataset...[/cyan]")
|
||||
if args.data_dir:
|
||||
dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir)
|
||||
dataset_root = Path(args.data_dir)
|
||||
else:
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id)
|
||||
dataset_root = dataset.root
|
||||
|
||||
console.print(f"[green]✓ Loaded dataset with {len(dataset)} frames[/green]")
|
||||
|
||||
# Load skills metadata
|
||||
console.print("[cyan]Loading skills metadata...[/cyan]")
|
||||
skills_metadata = load_skills_metadata(dataset_root)
|
||||
if skills_metadata is None:
|
||||
console.print("[red]Error: No skills.json found. Run annotate.py first![/red]")
|
||||
return
|
||||
|
||||
console.print(f"[green]✓ Loaded skills for {len(skills_metadata.get('episodes', {}))} episodes[/green]")
|
||||
|
||||
# Initialize model
|
||||
dtype_map = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtype_map[args.dtype]
|
||||
|
||||
console.print(f"[cyan]Initializing {args.model}...[/cyan]")
|
||||
pgen = QwenPgen(
|
||||
model_name=args.model,
|
||||
device=args.device,
|
||||
torch_dtype=torch_dtype,
|
||||
temperature=args.temperature,
|
||||
)
|
||||
|
||||
# Get image keys
|
||||
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
|
||||
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
|
||||
|
||||
# Generate synthetic data
|
||||
tasks_df, task_indices, debug_outputs = generate_synthetic_data(
|
||||
dataset=dataset,
|
||||
pgen=pgen,
|
||||
skills_metadata=skills_metadata,
|
||||
image_keys=image_keys,
|
||||
sample_interval_seconds=args.sample_interval,
|
||||
console=console,
|
||||
)
|
||||
|
||||
# Save high-level tasks
|
||||
save_high_level_tasks(tasks_df, dataset_root, console)
|
||||
save_debug_outputs(debug_outputs, dataset_root, console)
|
||||
|
||||
# Add task_index_high_level feature to dataset
|
||||
console.print("[cyan]Adding task_index_high_level feature to dataset...[/cyan]")
|
||||
|
||||
# Determine output directory
|
||||
if args.output_dir:
|
||||
output_dir = Path(args.output_dir)
|
||||
repo_id = f"{dataset.repo_id}_with_high_level_tasks"
|
||||
else:
|
||||
output_dir = None
|
||||
repo_id = f"{dataset.repo_id}_with_high_level_tasks"
|
||||
|
||||
# Add feature using dataset_tools
|
||||
feature_info = {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
breakpoint()
|
||||
new_dataset = add_features(
|
||||
dataset=dataset,
|
||||
features={
|
||||
"task_index_high_level": (task_indices, feature_info),
|
||||
},
|
||||
output_dir=output_dir,
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
console.print(f"[bold green]✓ Successfully added task_index_high_level feature![/bold green]")
|
||||
console.print(f" New dataset saved to: {new_dataset.root}")
|
||||
console.print(f" Total high-level tasks: {len(tasks_df)}")
|
||||
|
||||
# Push to hub if requested
|
||||
if args.push_to_hub:
|
||||
if args.data_dir:
|
||||
console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]")
|
||||
else:
|
||||
console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
|
||||
try:
|
||||
new_dataset.push_to_hub(push_videos=False)
|
||||
console.print(f"[green]✓ Pushed to {repo_id}[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Push failed: {e}[/red]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user