diff --git a/examples/dataset/annotate.py b/examples/dataset/annotate.py new file mode 100644 index 000000000..e03c4b961 --- /dev/null +++ b/examples/dataset/annotate.py @@ -0,0 +1,1143 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Automatic Skill Annotation for LeRobot Datasets. + +This script performs automatic subtask/skill labeling for ANY LeRobot dataset using +Vision-Language Models (VLMs). It segments each robot demonstration into short atomic +skills (1-3 seconds each) and updates the dataset's task field. + +The pipeline: +1. Loads a LeRobot dataset (local or from HuggingFace Hub) +2. For each episode, extracts video frames +3. Uses a VLM to identify skill boundaries and labels +4. Updates the dataset's task metadata with skill annotations + +Supported VLMs (modular design allows easy extension): +- Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct" +- Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct" +- SmolVLM: "HuggingFaceTB/SmolVLM-Instruct" + +Usage: +```bash +python examples/dataset/annotate.py \ + --repo-id your-username/your-dataset \ + --video-key observation.images.base \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --push-to-hub +``` + +Or with a local dataset: +```bash +python examples/dataset/annotate.py \ + --data-dir /path/to/local/dataset \ + --video-key observation.images.base +``` +After running, you can access the skill for any frame via: +```python +dataset = LeRobotDataset(repo_id="your/dataset") +item = dataset[100] +task_idx = item["task_index"] +skill_name = dataset.meta.tasks.iloc[task_idx].name +``` +""" + +import argparse +import json +import re +import subprocess +import tempfile +import textwrap +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import cv2 +import torch +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +# ============================================================================= +# Skill Annotation Data Structures +# ============================================================================= + + +class Skill: + """Represents a single atomic skill/subtask in a demonstration.""" + + def __init__(self, name: str, start: float, end: float): + self.name = name + self.start = start # Start timestamp in seconds + self.end = end # End timestamp in seconds + + def to_dict(self) -> dict: + return {"name": self.name, "start": self.start, "end": self.end} + + @classmethod + def from_dict(cls, data: dict) -> "Skill": + return cls(name=data["name"], start=data["start"], end=data["end"]) + + def __repr__(self) -> str: + return f"Skill(name='{self.name}', start={self.start:.2f}, end={self.end:.2f})" + + +class EpisodeSkills: + """Container for all skills in an episode.""" + + def __init__(self, episode_index: int, description: str, skills: list[Skill]): + self.episode_index = episode_index + self.description = 
description + self.skills = skills + + def to_dict(self) -> dict: + return { + "episode_index": self.episode_index, + "description": self.description, + "skills": [s.to_dict() for s in self.skills], + } + + +# ============================================================================= +# VLM Interface (Abstract Base Class for Modularity) +# ============================================================================= + + +class BaseVLM(ABC): + """ + Abstract base class for Vision-Language Models. + + To add a new VLM: + 1. Create a subclass of BaseVLM + 2. Implement the `__init__` and `segment_skills` methods + 3. Register it in the VLM_REGISTRY dictionary + """ + + @abstractmethod + def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): + """Initialize the VLM with model name, device, and dtype.""" + pass + + @abstractmethod + def segment_skills( + self, video_path: Path, episode_duration: float, coarse_goal: str | None = None + ) -> list[Skill]: + """ + Segment a video into atomic skills. + + Args: + video_path: Path to the video file + episode_duration: Total duration of the episode in seconds + coarse_goal: Optional high-level task description + + Returns: + List of Skill objects representing atomic manipulation skills + """ + pass + + +def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str: + """Create the prompt for skill segmentation.""" + goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else "" + + return textwrap.dedent(f"""\ + # Role + You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations. + + # Task + {goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should: + - Last approximately 1-3 seconds + - Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper") + - Have precise start and end timestamps + + # Requirements + 1. **Atomic Actions**: Each skill should be a single, indivisible action + 2. **Complete Coverage**: Skills must cover the entire video duration with no gaps + 3. **Boundary Consistency**: The end of one skill equals the start of the next + 4. **Natural Language**: Use clear, descriptive names for each skill + 5. **Timestamps**: Use seconds (float) for all timestamps + + # Analysis Steps + 1. First, describe what you observe in the video chronologically + 2. Identify distinct motion phases and state changes + 3. Determine precise boundaries based on visual state transitions + + # Output Format + After your analysis, output ONLY valid JSON with this exact structure: + + ```json + {{ + "skills": [ + {{"name": "skill description", "start": 0.0, "end": 1.5}}, + {{"name": "another skill", "start": 1.5, "end": 3.2}} + ] + }} + ``` + + The first skill must start at 0.0 and the last skill must end at the video duration. 
+ """) + + +# ============================================================================= +# Qwen2-VL Implementation +# ============================================================================= + + +class Qwen2VL(BaseVLM): + """Qwen2-VL model for skill segmentation.""" + + def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): + from qwen_vl_utils import process_vision_info + from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + self.console = Console() + self.device = device + self.model_name = model_name + self.process_vision_info = process_vision_info + + self.console.print(f"[cyan]Loading Qwen2-VL model: {model_name}...[/cyan]") + + self.model = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]") + + def segment_skills( + self, video_path: Path, episode_duration: float, coarse_goal: str | None = None + ) -> list[Skill]: + """Segment video into skills using Qwen2-VL.""" + prompt = create_skill_segmentation_prompt(coarse_goal) + duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}" + + messages = [ + {"role": "system", "content": [{"type": "text", "text": prompt}]}, + { + "role": "user", + "content": [ + {"type": "video", "video": str(video_path), "fps": 1.0}, + { + "type": "text", + "text": f"Video duration: {duration_str} (~{episode_duration:.1f}s). Segment into atomic skills.", + }, + ], + }, + ] + + text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_inputs, video_inputs = self.process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(self.device) + + with torch.no_grad(): + generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7) + + response = self.processor.batch_decode( + [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)], + skip_special_tokens=True, + )[0].strip() + + return self._parse_skills_response(response) + + def _parse_skills_response(self, response: str) -> list[Skill]: + """Parse the VLM response into Skill objects.""" + # Extract JSON from response + if "```json" in response: + response = response.split("```json")[1].split("```")[0] + elif "```" in response: + response = response.split("```")[1].split("```")[0] + + try: + data = json.loads(response) + skills_data = data.get("skills", data) + if isinstance(skills_data, list): + return [Skill.from_dict(s) for s in skills_data] + except json.JSONDecodeError: + # Try to find JSON object in response + match = re.search(r"\{.*\}", response, re.DOTALL) + if match: + data = json.loads(match.group()) + skills_data = data.get("skills", []) + return [Skill.from_dict(s) for s in skills_data] + + raise ValueError(f"Could not parse skills from response: {response[:200]}...") + + +# ============================================================================= +# Qwen3-VL Implementation (MoE variant) +# ============================================================================= + + +class Qwen3VL(BaseVLM): + """Qwen3-VL MoE model for skill segmentation.""" + + def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): + from 
qwen_vl_utils import process_vision_info + from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration + + self.console = Console() + self.device = device + self.model_name = model_name + self.process_vision_info = process_vision_info + + self.console.print(f"[cyan]Loading Qwen3-VL model: {model_name}...[/cyan]") + + self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained( + model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]") + + def segment_skills( + self, video_path: Path, episode_duration: float, coarse_goal: str | None = None + ) -> list[Skill]: + """Segment video into skills using Qwen3-VL.""" + prompt = create_skill_segmentation_prompt(coarse_goal) + duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}" + + messages = [ + {"role": "system", "content": [{"type": "text", "text": prompt}]}, + { + "role": "user", + "content": [ + {"type": "video", "video": str(video_path), "fps": 1.0}, + { + "type": "text", + "text": f"Video duration: {duration_str} (~{episode_duration:.1f}s). Segment into atomic skills.", + }, + ], + }, + ] + + text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_inputs, video_inputs = self.process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(self.device) + + with torch.no_grad(): + generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7) + + response = self.processor.batch_decode( + [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)], + skip_special_tokens=True, + )[0].strip() + + return self._parse_skills_response(response) + + def _parse_skills_response(self, response: str) -> list[Skill]: + """Parse the VLM response into Skill objects.""" + if "```json" in response: + response = response.split("```json")[1].split("```")[0] + elif "```" in response: + response = response.split("```")[1].split("```")[0] + + try: + data = json.loads(response) + skills_data = data.get("skills", data) + if isinstance(skills_data, list): + return [Skill.from_dict(s) for s in skills_data] + except json.JSONDecodeError: + match = re.search(r"\{.*\}", response, re.DOTALL) + if match: + data = json.loads(match.group()) + skills_data = data.get("skills", []) + return [Skill.from_dict(s) for s in skills_data] + + raise ValueError(f"Could not parse skills from response: {response[:200]}...") + + +# ============================================================================= +# SmolVLM Implementation +# ============================================================================= + + +class SmolVLM(BaseVLM): + """SmolVLM model for skill segmentation (lighter weight alternative).""" + + def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): + from transformers import AutoModelForVision2Seq, AutoProcessor + + self.console = Console() + self.device = device + self.model_name = model_name + + self.console.print(f"[cyan]Loading SmolVLM model: {model_name}...[/cyan]") + + self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + self.model = AutoModelForVision2Seq.from_pretrained( + model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True + 
) + + self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]") + + def segment_skills( + self, video_path: Path, episode_duration: float, coarse_goal: str | None = None + ) -> list[Skill]: + """Segment video into skills using SmolVLM with frame sampling.""" + import PIL.Image + + # SmolVLM works with images, so we sample frames from the video + frames = self._extract_frames(video_path, target_fps=1) + + if not frames: + raise ValueError(f"Could not extract frames from {video_path}") + + prompt = create_skill_segmentation_prompt(coarse_goal) + duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}" + + # Create message with sampled frames + content = [{"type": "text", "text": prompt}] + + # Add frames as images (sample up to 8 frames to avoid context overflow) + frame_indices = self._select_frame_indices(len(frames), max_frames=8) + for idx in frame_indices: + frame = frames[idx] + pil_image = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + content.append({"type": "image", "image": pil_image}) + + content.append( + { + "type": "text", + "text": f"These are {len(frame_indices)} sampled frames from a {duration_str} video. Segment into atomic skills.", + } + ) + + messages = [{"role": "user", "content": content}] + + inputs = self.processor( + text=self.processor.apply_chat_template(messages, add_generation_prompt=True), + images=[PIL.Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)) for i in frame_indices], + return_tensors="pt", + ).to(self.device) + + with torch.no_grad(): + generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7) + + response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() + + return self._parse_skills_response(response, episode_duration) + + def _extract_frames(self, video_path: Path, target_fps: int = 1) -> list: + """Extract frames from video at target FPS.""" + cap = cv2.VideoCapture(str(video_path)) + frames = [] + fps = cap.get(cv2.CAP_PROP_FPS) or 30 + frame_interval = int(fps / target_fps) + + frame_count = 0 + while True: + ret, frame = cap.read() + if not ret: + break + if frame_count % frame_interval == 0: + frames.append(frame) + frame_count += 1 + + cap.release() + return frames + + def _select_frame_indices(self, total_frames: int, max_frames: int = 8) -> list[int]: + """Select evenly spaced frame indices.""" + if total_frames <= max_frames: + return list(range(total_frames)) + step = total_frames / max_frames + return [int(i * step) for i in range(max_frames)] + + def _parse_skills_response(self, response: str, episode_duration: float) -> list[Skill]: + """Parse the VLM response into Skill objects.""" + if "```json" in response: + response = response.split("```json")[1].split("```")[0] + elif "```" in response: + response = response.split("```")[1].split("```")[0] + + try: + data = json.loads(response) + skills_data = data.get("skills", data) + if isinstance(skills_data, list): + return [Skill.from_dict(s) for s in skills_data] + except json.JSONDecodeError: + match = re.search(r"\{.*\}", response, re.DOTALL) + if match: + data = json.loads(match.group()) + skills_data = data.get("skills", []) + return [Skill.from_dict(s) for s in skills_data] + + # Fallback: create a single skill covering the whole episode + self.console.print("[yellow]Warning: Could not parse skills, creating single skill[/yellow]") + return [Skill(name="perform manipulation", start=0.0, end=episode_duration)] + + +# 
============================================================================= +# VLM Registry - Add new VLMs here +# ============================================================================= + +VLM_REGISTRY: dict[str, type[BaseVLM]] = { + # Qwen2-VL variants + "Qwen/Qwen2-VL-2B-Instruct": Qwen2VL, + "Qwen/Qwen2-VL-7B-Instruct": Qwen2VL, + "Qwen/Qwen2-VL-72B-Instruct": Qwen2VL, + # Qwen3-VL variants (MoE) + "Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL, + # SmolVLM variants + "HuggingFaceTB/SmolVLM-Instruct": SmolVLM, + "HuggingFaceTB/SmolVLM-256M-Instruct": SmolVLM, + "HuggingFaceTB/SmolVLM-500M-Instruct": SmolVLM, +} + + +def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16) -> BaseVLM: + """ + Factory function to get the appropriate VLM based on model name. + + Args: + model_name: HuggingFace model identifier + device: Device to load model on + torch_dtype: Data type for model weights + + Returns: + Initialized VLM instance + + Raises: + ValueError: If model is not in registry + """ + # Check exact match first + if model_name in VLM_REGISTRY: + return VLM_REGISTRY[model_name](model_name, device, torch_dtype) + + # Check for partial matches (e.g., "qwen2" in model name) + model_lower = model_name.lower() + if "qwen3" in model_lower: + return Qwen3VL(model_name, device, torch_dtype) + elif "qwen2" in model_lower or "qwen-vl" in model_lower: + return Qwen2VL(model_name, device, torch_dtype) + elif "smolvlm" in model_lower: + return SmolVLM(model_name, device, torch_dtype) + + raise ValueError( + f"Unknown model: {model_name}. " + f"Supported models: {list(VLM_REGISTRY.keys())}. " + "Or implement a new VLM class inheriting from BaseVLM." + ) + + +# ============================================================================= +# Video Extraction Utilities +# ============================================================================= + + +class VideoExtractor: + """Utilities for extracting and processing video segments from LeRobot datasets.""" + + def __init__(self, console: Console | None = None): + self.console = console or Console() + + def extract_episode_video( + self, + video_path: Path, + start_timestamp: float, + end_timestamp: float, + target_fps: int = 1, + ) -> Path: + """ + Extract a specific episode segment from a concatenated video file. + + Args: + video_path: Path to the source video file + start_timestamp: Start time in seconds + end_timestamp: End time in seconds + target_fps: Target frames per second for output + + Returns: + Path to the extracted temporary video file + """ + tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) + tmp_path = Path(tmp_file.name) + tmp_file.close() + + duration = end_timestamp - start_timestamp + + self.console.print( + f"[cyan]Extracting: {start_timestamp:.1f}s - {end_timestamp:.1f}s ({duration:.1f}s)[/cyan]" + ) + + cmd = [ + "ffmpeg", + "-i", + str(video_path), + "-ss", + str(start_timestamp), + "-t", + str(duration), + "-r", + str(target_fps), + "-c:v", + "libx264", + "-preset", + "ultrafast", + "-crf", + "23", + "-an", + "-y", + str(tmp_path), + ] + + try: + subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"FFmpeg failed: {e}") from e + except FileNotFoundError: + raise RuntimeError("FFmpeg not found. 
Please install ffmpeg.") + + if not tmp_path.exists() or tmp_path.stat().st_size < 1024: + if tmp_path.exists(): + tmp_path.unlink() + raise RuntimeError("Video extraction produced invalid file") + + return tmp_path + + def get_video_duration(self, video_path: Path) -> float: + """Get duration of a video file in seconds.""" + cap = cv2.VideoCapture(str(video_path)) + fps = cap.get(cv2.CAP_PROP_FPS) or 30 + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + return frame_count / fps + + +# ============================================================================= +# Skill Annotation Pipeline +# ============================================================================= + + +class SkillAnnotator: + """ + Main class for annotating LeRobot datasets with skill labels. + + This class orchestrates the full annotation pipeline: + 1. Load dataset + 2. Extract video segments for each episode + 3. Run VLM-based skill segmentation + 4. Update dataset task metadata + """ + + def __init__( + self, + vlm: BaseVLM, + video_extractor: VideoExtractor | None = None, + console: Console | None = None, + ): + self.vlm = vlm + self.console = console or Console() + self.video_extractor = video_extractor or VideoExtractor(self.console) + + def annotate_dataset( + self, + dataset: LeRobotDataset, + video_key: str, + episodes: list[int] | None = None, + skip_existing: bool = False, + ) -> dict[int, EpisodeSkills]: + """ + Annotate all episodes in a dataset with skill labels. + + Args: + dataset: LeRobot dataset to annotate + video_key: Key for video observations (e.g., "observation.images.base") + episodes: Specific episode indices to annotate (None = all) + skip_existing: Skip episodes that already have skill annotations + + Returns: + Dictionary mapping episode index to EpisodeSkills + """ + episode_indices = episodes or list(range(dataset.meta.total_episodes)) + annotations: dict[int, EpisodeSkills] = {} + + # Get coarse task description if available + coarse_goal = self._get_coarse_goal(dataset) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=self.console, + ) as progress: + task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices)) + + for ep_idx in episode_indices: + progress.update(task, description=f"Processing episode {ep_idx}...") + + try: + skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal) + annotations[ep_idx] = EpisodeSkills( + episode_index=ep_idx, + description=coarse_goal, + skills=skills, + ) + self.console.print( + f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]" + ) + except Exception as e: + self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]") + + progress.advance(task) + + return annotations + + def _get_coarse_goal(self, dataset: LeRobotDataset) -> str: + """Extract or generate the coarse task description.""" + # Try to get from existing task metadata + if dataset.meta.tasks is not None and len(dataset.meta.tasks) > 0: + # Get the first task description + first_task = dataset.meta.tasks.index[0] + if first_task: + return str(first_task) + + return "Perform the demonstrated manipulation task." 
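+    # Illustrative note (hypothetical values): for a ~12 s episode, the VLM call in
+    # _annotate_episode() below might return contiguous segments such as
+    #   [Skill("reach toward cube", 0.0, 2.5),
+    #    Skill("grasp cube", 2.5, 4.0),
+    #    Skill("lift cube and move to bin", 4.0, 9.0),
+    #    Skill("release cube and retract", 9.0, 12.0)]
+    # i.e. atomic skills whose boundaries tile the full episode duration.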
+ + def _annotate_episode( + self, + dataset: LeRobotDataset, + episode_index: int, + video_key: str, + coarse_goal: str, + ) -> list[Skill]: + """Annotate a single episode with skill labels.""" + # Get video path and timestamps for this episode + video_path = dataset.root / dataset.meta.get_video_file_path(episode_index, video_key) + + if not video_path.exists(): + raise FileNotFoundError(f"Video not found: {video_path}") + + # Get episode timestamps from metadata + ep = dataset.meta.episodes[episode_index] + start_ts = float(ep[f"videos/{video_key}/from_timestamp"]) + end_ts = float(ep[f"videos/{video_key}/to_timestamp"]) + duration = end_ts - start_ts + + # Extract episode segment to temporary file + extracted_path = self.video_extractor.extract_episode_video( + video_path, start_ts, end_ts, target_fps=1 + ) + + try: + # Run VLM skill segmentation + skills = self.vlm.segment_skills(extracted_path, duration, coarse_goal) + return skills + finally: + # Clean up temporary file + if extracted_path.exists(): + extracted_path.unlink() + + +# ============================================================================= +# Metadata Writer - Updates per-frame task_index based on skills +# ============================================================================= + + +def get_skill_for_timestamp(skills: list[Skill], timestamp: float) -> Skill | None: + """ + Find which skill covers a given timestamp. + + Args: + skills: List of skills with start/end times + timestamp: Frame timestamp in seconds + + Returns: + The Skill that covers this timestamp, or None if not found + """ + for skill in skills: + if skill.start <= timestamp < skill.end: + return skill + # Handle the last frame (end boundary) + if timestamp >= skill.end and skill == skills[-1]: + return skill + return skills[-1] if skills else None # Fallback to last skill + + +def update_dataset_tasks( + dataset: LeRobotDataset, + annotations: dict[int, EpisodeSkills], +) -> dict[str, int]: + """ + Register all unique skill names as new tasks in the dataset. 
+ + Args: + dataset: The LeRobot dataset to update + annotations: Dictionary of episode skills + + Returns: + Dictionary mapping skill name to task_index + """ + import pandas as pd + + from lerobot.datasets.utils import write_tasks + + console = Console() + + # Collect all unique skill names + all_skill_names: set[str] = set() + for episode_skills in annotations.values(): + for skill in episode_skills.skills: + all_skill_names.add(skill.name) + + console.print(f"[cyan]Found {len(all_skill_names)} unique skills[/cyan]") + + # Build new tasks DataFrame + # Start with existing tasks if any + if dataset.meta.tasks is not None and len(dataset.meta.tasks) > 0: + existing_tasks = set(dataset.meta.tasks.index.tolist()) + max_task_idx = dataset.meta.tasks["task_index"].max() + else: + existing_tasks = set() + max_task_idx = -1 + + # Add new skills as tasks + new_tasks = all_skill_names - existing_tasks + if new_tasks: + new_task_data = [] + for i, skill_name in enumerate(sorted(new_tasks)): + new_task_data.append({ + "task": skill_name, + "task_index": max_task_idx + 1 + i, + }) + + new_tasks_df = pd.DataFrame(new_task_data).set_index("task") + + if dataset.meta.tasks is not None and len(dataset.meta.tasks) > 0: + dataset.meta.tasks = pd.concat([dataset.meta.tasks, new_tasks_df]) + else: + dataset.meta.tasks = new_tasks_df + + # Write updated tasks to disk + write_tasks(dataset.meta.tasks, dataset.root) + console.print(f"[green]✓ Added {len(new_tasks)} new tasks to tasks.parquet[/green]") + + # Build skill name to task_index mapping + skill_to_task_idx = { + task_name: int(dataset.meta.tasks.loc[task_name, "task_index"]) + for task_name in all_skill_names + } + + return skill_to_task_idx + + +def update_frame_task_indices( + dataset: LeRobotDataset, + annotations: dict[int, EpisodeSkills], + skill_to_task_idx: dict[str, int], +) -> None: + """ + Update the task_index for each frame based on skill annotations. + + This reads the data parquet files, updates task_index based on which + skill covers each frame's timestamp, and writes back to disk. 
+ + Args: + dataset: The LeRobot dataset to update + annotations: Dictionary of episode skills + skill_to_task_idx: Mapping from skill name to task_index + """ + import pandas as pd + + console = Console() + + # Group episodes by their data file (chunk_index, file_index) + episodes_by_file: dict[tuple[int, int], list[int]] = {} + for ep_idx in annotations.keys(): + ep = dataset.meta.episodes[ep_idx] + chunk_idx = ep["data/chunk_index"] + file_idx = ep["data/file_index"] + key = (chunk_idx, file_idx) + if key not in episodes_by_file: + episodes_by_file[key] = [] + episodes_by_file[key].append(ep_idx) + + # Process each data file + for (chunk_idx, file_idx), episode_indices in episodes_by_file.items(): + data_path = dataset.root / dataset.meta.data_path.format( + chunk_index=chunk_idx, file_index=file_idx + ) + + if not data_path.exists(): + console.print(f"[yellow]Warning: Data file not found: {data_path}[/yellow]") + continue + + # Read the parquet file + df = pd.read_parquet(data_path) + original_task_indices = df["task_index"].copy() + updated_count = 0 + + # Update task_index for each episode in this file + for ep_idx in episode_indices: + if ep_idx not in annotations: + continue + + episode_skills = annotations[ep_idx] + skills = episode_skills.skills + + # Get episode frame range + ep = dataset.meta.episodes[ep_idx] + ep_from = ep["dataset_from_index"] + ep_to = ep["dataset_to_index"] + + # Filter to rows for this episode + episode_mask = (df["index"] >= ep_from) & (df["index"] < ep_to) + episode_rows = df.loc[episode_mask] + + # Update task_index for each frame based on its timestamp + for idx, row in episode_rows.iterrows(): + timestamp = row["timestamp"] + skill = get_skill_for_timestamp(skills, timestamp) + + if skill and skill.name in skill_to_task_idx: + new_task_idx = skill_to_task_idx[skill.name] + if df.at[idx, "task_index"] != new_task_idx: + df.at[idx, "task_index"] = new_task_idx + updated_count += 1 + + # Write back if any changes were made + if updated_count > 0: + df.to_parquet(data_path, engine="pyarrow", compression="snappy", index=False) + console.print( + f"[green]✓ Updated {updated_count} frame task_indices in {data_path.name}[/green]" + ) + + +def save_skill_annotations( + dataset: LeRobotDataset, + annotations: dict[int, EpisodeSkills], + output_path: Path | None = None, +) -> None: + """ + Save skill annotations to the dataset, updating both: + 1. The tasks.parquet with new skill names + 2. The per-frame task_index in data parquet files + + This function updates the task field for each frame based on + which skill covers that frame's timestamp. 
+ + Args: + dataset: The LeRobot dataset to update + annotations: Dictionary of episode skills + output_path: Optional custom output path for the annotations JSON + """ + console = Console() + + if not annotations: + console.print("[yellow]No annotations to save[/yellow]") + return + + # Step 1: Register all unique skills as tasks + console.print("[cyan]Registering skills as tasks...[/cyan]") + skill_to_task_idx = update_dataset_tasks(dataset, annotations) + + # Step 2: Update per-frame task_index in data parquet files + console.print("[cyan]Updating per-frame task indices...[/cyan]") + update_frame_task_indices(dataset, annotations, skill_to_task_idx) + + # Step 3: Also save the raw skill annotations as JSON for reference + skills_data = { + "coarse_description": annotations[next(iter(annotations))].description, + "skill_to_task_index": skill_to_task_idx, + "episodes": {str(ep_idx): ann.to_dict() for ep_idx, ann in annotations.items()}, + } + + skills_path = output_path or (dataset.root / "meta" / "skills.json") + skills_path.parent.mkdir(parents=True, exist_ok=True) + + with open(skills_path, "w") as f: + json.dump(skills_data, f, indent=2) + + console.print(f"[green]✓ Saved skill annotations to {skills_path}[/green]") + + # Reload the dataset's hf_dataset to reflect changes + dataset._lazy_loading = True + + +def load_skill_annotations(dataset_root: Path) -> dict | None: + """Load existing skill annotations from a dataset.""" + skills_path = dataset_root / "meta" / "skills.json" + if skills_path.exists(): + with open(skills_path) as f: + return json.load(f) + return None + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + + +def main(): + """Main entry point for the skill annotation script.""" + parser = argparse.ArgumentParser( + description="Automatic skill annotation for LeRobot datasets using VLMs", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=textwrap.dedent("""\ + Examples: + # Annotate a HuggingFace Hub dataset + python annotate.py --repo-id user/dataset --video-key observation.images.base + + # Annotate a local dataset + python annotate.py --data-dir /path/to/dataset --video-key observation.images.base + + # Use a specific model + python annotate.py --repo-id user/dataset --video-key observation.images.base \\ + --model Qwen/Qwen2-VL-7B-Instruct + + # Push annotated dataset to Hub + python annotate.py --repo-id user/dataset --video-key observation.images.base --push-to-hub + """), + ) + + # Data source (mutually exclusive) + data_group = parser.add_mutually_exclusive_group(required=True) + data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset") + data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID") + + # Required arguments + parser.add_argument( + "--video-key", + type=str, + required=True, + help="Video observation key (e.g., 'observation.images.base')", + ) + + # Model configuration + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2-VL-7B-Instruct", + help="VLM model to use for skill segmentation (default: Qwen/Qwen2-VL-7B-Instruct)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run model on (default: cuda)", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float16", "float32"], + help="Model dtype (default: bfloat16)", + ) + + # Episode 
selection + parser.add_argument( + "--episodes", + type=int, + nargs="+", + help="Specific episode indices to annotate (default: all)", + ) + parser.add_argument( + "--skip-existing", + action="store_true", + help="Skip episodes that already have annotations", + ) + + # Output options + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Push annotated dataset to HuggingFace Hub", + ) + parser.add_argument( + "--output-path", + type=str, + help="Custom output path for annotations JSON", + ) + + args = parser.parse_args() + console = Console() + + # Validate arguments + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + torch_dtype = dtype_map[args.dtype] + + # Load dataset + console.print("[cyan]Loading dataset...[/cyan]") + if args.data_dir: + dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir, download_videos=False) + else: + dataset = LeRobotDataset(repo_id=args.repo_id, download_videos=True) + + console.print(f"[green]✓ Loaded dataset with {dataset.meta.total_episodes} episodes[/green]") + + # Validate video key + if args.video_key not in dataset.meta.video_keys: + available = ", ".join(dataset.meta.video_keys) + console.print(f"[red]Error: Video key '{args.video_key}' not found. Available: {available}[/red]") + return + + # Initialize VLM + console.print(f"[cyan]Initializing VLM: {args.model}...[/cyan]") + vlm = get_vlm(args.model, args.device, torch_dtype) + + # Create annotator and run annotation + annotator = SkillAnnotator(vlm=vlm, console=console) + annotations = annotator.annotate_dataset( + dataset=dataset, + video_key=args.video_key, + episodes=args.episodes, + skip_existing=args.skip_existing, + ) + + # Save annotations + output_path = Path(args.output_path) if args.output_path else None + save_skill_annotations(dataset, annotations, output_path) + + # Summary + total_skills = sum(len(ann.skills) for ann in annotations.values()) + console.print(f"\n[bold green]✓ Annotation complete![/bold green]") + console.print(f" Episodes annotated: {len(annotations)}") + console.print(f" Total skills identified: {total_skills}") + + # Push to hub if requested + if args.push_to_hub: + if args.data_dir: + console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]") + else: + console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]") + try: + dataset.push_to_hub(push_videos=False) + console.print(f"[green]✓ Pushed to {args.repo_id}[/green]") + except Exception as e: + console.print(f"[red]Push failed: {e}[/red]") + + +if __name__ == "__main__": + main() + diff --git a/examples/dataset/run.sh b/examples/dataset/run.sh new file mode 100644 index 000000000..00e9fa344 --- /dev/null +++ b/examples/dataset/run.sh @@ -0,0 +1,5 @@ +python examples/dataset/annotate.py \ + --repo-id lerobot/svla_so101_pickplace \ + --video-key observation.images.side \ + --model HuggingFaceTB/SmolVLM-Instruct \ + \ No newline at end of file diff --git a/tests/envs/__init__.py b/tests/envs/__init__.py new file mode 100644 index 000000000..adfc257a2 --- /dev/null +++ b/tests/envs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/envs/dummy_hub_env.py b/tests/envs/dummy_hub_env.py new file mode 100644 index 000000000..50b6bc2eb --- /dev/null +++ b/tests/envs/dummy_hub_env.py @@ -0,0 +1,55 @@ +# env.py - Upload this to your Hub repository +# Example: huggingface.co/your-username/test-kwargs-env + +import gymnasium as gym +from gymnasium.vector import SyncVectorEnv + + +def make_env( + n_envs=1, + use_async_envs=False, + config_path=None, + config_overrides=None, + **kwargs, +): + """ + Create vectorized CartPole environments with configurable options. + + Args: + n_envs: Number of parallel environments + use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv + config_path: Optional path to a config file (for demonstration) + config_overrides: Optional dict of config overrides + **kwargs: Additional configuration options + + Returns: + dict mapping suite name to task environments + """ + # Merge all config sources for demonstration + config = {} + if config_overrides: + config.update(config_overrides) + config.update(kwargs) + + # Store config in a way the test can verify + # In a real env, you'd use these to configure the simulation + stored_config = { + "config_path": config_path, + "config_overrides": config_overrides, + "extra_kwargs": kwargs, + } + + def _mk(): + env = gym.make("CartPole-v1") + # Attach config to env for test verification + env.hub_config = stored_config + return env + + Vec = gym.vector.AsyncVectorEnv if use_async_envs else SyncVectorEnv + vec_env = Vec([_mk for _ in range(n_envs)]) + + # Also attach to vector env for easy access in tests + vec_env.hub_config = stored_config + + return {"cartpole_suite": {0: vec_env}} +
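For reference, a minimal sketch of how a test might exercise `make_env` from `tests/envs/dummy_hub_env.py` and verify the attached `hub_config`; the import path and the concrete kwargs below are illustrative assumptions, not part of this PR:

```python
# Hypothetical usage of the dummy hub env factory above; the import path and
# the kwargs are illustrative only.
from tests.envs.dummy_hub_env import make_env

envs = make_env(n_envs=2, config_path="cfg.yaml", config_overrides={"seed": 7}, render_mode=None)
vec_env = envs["cartpole_suite"][0]

# make_env attaches the merged configuration so a test can verify kwarg plumbing.
assert vec_env.hub_config["config_path"] == "cfg.yaml"
assert vec_env.hub_config["config_overrides"] == {"seed": 7}
assert vec_env.hub_config["extra_kwargs"] == {"render_mode": None}

# The factory returns a {suite_name: {task_id: vector_env}} mapping of CartPole envs.
obs, info = vec_env.reset(seed=0)
assert obs.shape[0] == 2  # one observation per parallel environment
```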