diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index d9cc28b30..93f91cb88 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -58,6 +58,7 @@ from lerobot.datasets.utils import ( load_nested_dataset, load_stats, load_tasks, + load_tasks_high_level, update_chunk_file_indices, validate_episode_buffer, validate_frame, @@ -162,6 +163,7 @@ class LeRobotDatasetMetadata: self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) + self.tasks_high_level = load_tasks_high_level(self.root) self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) @@ -1060,6 +1062,12 @@ class LeRobotDataset(torch.utils.data.Dataset): # Add task as a string task_idx = item["task_index"].item() item["task"] = self.meta.tasks.iloc[task_idx].name + + # optionally add high level task index + if "task_index_high_level" in self.features: + high_level_task_idx = item["task_index_high_level"].item() + item["robot_utterance"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["robot_utterance"] + item["user_prompt"] = self.meta.tasks_high_level.iloc[high_level_task_idx]["user_prompt"] return item def __repr__(self): diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 234736a75..d3a146fc1 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -61,6 +61,7 @@ VIDEO_DIR = "videos" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" DEFAULT_TASKS_PATH = "meta/tasks.parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" +DEFAULT_TASKS_HIGH_LEVEL_PATH = "meta/tasks_high_level.parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" DEFAULT_IMAGE_PATH = 
def load_tasks_high_level(local_dir: Path) -> pandas.DataFrame | None:
    """Load the optional high-level task table of a dataset.

    The table lives at ``DEFAULT_TASKS_HIGH_LEVEL_PATH`` relative to
    ``local_dir``. It only exists for datasets that went through the
    high-level annotation step, so its absence is not an error.

    Args:
        local_dir: Root directory of the dataset.

    Returns:
        The parquet content as a DataFrame, or ``None`` when the dataset has
        no high-level task file.
    """
    tasks_path = local_dir / DEFAULT_TASKS_HIGH_LEVEL_PATH
    # This loader is called unconditionally while loading dataset metadata,
    # so a missing optional file must not make every non-annotated dataset
    # fail to load.
    if not tasks_path.is_file():
        return None
    return pd.read_parquet(tasks_path)
{π₀.₅: a Vision-Language-Action Model with Open-World Generalization}, + author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky}, + year = {2025}, + eprint = {2504.16054}, + archivePrefix= {arXiv}, + primaryClass = {cs.LG}, + url = {https://arxiv.org/abs/2504.16054}, +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). diff --git a/src/lerobot/policies/pi05_full/__init__.py b/src/lerobot/policies/pi05_full/__init__.py new file mode 100644 index 000000000..4f9a9de4a --- /dev/null +++ b/src/lerobot/policies/pi05_full/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .configuration_pi05 import PI05Config +from .modeling_pi05 import PI05Policy +from .processor_pi05 import make_pi05_pre_post_processors + +__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"] diff --git a/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py b/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py new file mode 100644 index 000000000..d00931a4d --- /dev/null +++ b/src/lerobot/policies/pi05_full/annotate/high_level_annotate.py @@ -0,0 +1,1532 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Synthetic Data Generation for Hierarchical Policy Training. + +This script generates synthetic user prompts (ℓ_t) and robot utterances (u_t) for +hierarchical policy training using Qwen VLM as the generator model (pgen). + +The pipeline: +1. Loads a LeRobot dataset with skill annotations (from annotate.py) +2. For each frame, generates synthetic dialogue based on: + - Visual context (images at time t OR video clips in video mode) + - Current skill being performed + - History of previous skills + - High-level task description +3. 
# Video Extraction Utilities
class VideoExtractor:
    """Utilities for extracting and processing video segments from LeRobot datasets."""

    def __init__(self, console: Console | None = None):
        """Store the console used for logging, creating a default one if omitted."""
        self.console = console or Console()

    def extract_episode_video(
        self,
        video_path: Path,
        start_timestamp: float,
        end_timestamp: float,
        target_fps: int = 1,
    ) -> Path:
        """
        Extract a specific episode segment from a concatenated video file.

        The segment is re-encoded with ffmpeg (libx264, no audio) into a new
        temporary ``.mp4`` file. The caller owns the returned file and is
        responsible for deleting it.

        Args:
            video_path: Path to the source video file
            start_timestamp: Start time in seconds
            end_timestamp: End time in seconds
            target_fps: Target frames per second for output

        Returns:
            Path to the extracted temporary video file

        Raises:
            RuntimeError: If ffmpeg is not installed, exits with an error, or
                produces a missing / suspiciously small (< 1 KiB) output file.
        """
        # Only the name is needed: ffmpeg writes the file itself, so the handle
        # is closed right away and the file is kept on disk (delete=False).
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tmp_path = Path(tmp_file.name)
        tmp_file.close()

        duration = end_timestamp - start_timestamp

        cmd = [
            "ffmpeg",
            "-i",
            str(video_path),
            "-ss",
            str(start_timestamp),
            "-t",
            str(duration),
            "-r",
            str(target_fps),
            "-c:v",
            "libx264",
            "-preset",
            "ultrafast",
            "-crf",
            "23",
            "-an",
            "-y",
            str(tmp_path),
        ]

        try:
            try:
                subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
            except subprocess.CalledProcessError as e:
                raise RuntimeError(f"FFmpeg failed: {e}") from e
            except FileNotFoundError as e:
                raise RuntimeError("FFmpeg not found. Please install ffmpeg.") from e

            if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
                raise RuntimeError("Video extraction produced invalid file")
        except BaseException:
            # Every failure path must remove the temp file, otherwise each
            # failed episode leaves an orphaned .mp4 in the temp directory.
            tmp_path.unlink(missing_ok=True)
            raise

        return tmp_path

    def get_video_duration(self, video_path: Path) -> float:
        """Get duration of a video file in seconds.

        The duration is computed as frame count divided by the reported FPS;
        30 FPS is assumed when the container reports an FPS of 0.
        """
        cap = cv2.VideoCapture(str(video_path))
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the capture handle even if a property read fails.
            cap.release()
        return frame_count / fps
**user_prompt**: A natural-sounding user request that logically leads to the robot + performing the skill "{skill_current}" given the task context and history. + 2. **robot_utterance**: A natural robot reply acknowledging or clarifying the request. + + # Guidelines + - The user prompt should be grounded in the visual scene and task context + - Vary interaction types: direct commands, implicit requests, corrections, constraints + - Examples of user prompt styles: + * Direct: "Can you pick up the red brick?" + * Implicit: "I need something red for the tower" + * Negative: "Don't pick up the blue one" + * Constraint: "Make sure to handle it gently" + * Correction: "Actually, move to the other box instead" + - Robot responses should be appropriate: confirmations, clarifications, or error handling + - Use the skill history to ensure continuity (don't repeat past actions) + - Consider world knowledge (dietary preferences, object properties, etc.) + + # Scenario Types (choose one that fits): + - **specific_object**: User specifies exact object/action + - **negative_task**: User says what NOT to do + - **situated_correction**: User adjusts based on current state + - **implicit_request**: User implies need without direct command + - **constraint_based**: User adds specific constraints + + # Response Types (choose one that fits): + - **confirmation**: Simple "OK, I'll do X" + - **clarification**: "Just to confirm, you want me to..." + - **acknowledgment**: "Got it, [doing action]" + - **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]" + + # Output Format + Respond ONLY with valid JSON: + {{ + "scenario_type": "one of the types above", + "response_type": "one of the types above", + "user_prompt": "natural user request here", + "robot_utterance": "natural robot response here" + }} + + The responses must be grounded in the visual scene, the task, and the skill history. + Make it sound like a real human-robot interaction. 
+ """) + +PGEN_PROMPT_TEMPLATE_VIDEO = textwrap.dedent("""\ + # Role + You are a robot-assistant dialogue generator for hierarchical robot policies. + + # Task + You are watching a full robot demonstration video for the task: {task_description} + + For each timestamp below, generate natural human-robot dialogue that would have led to the observed behavior. + At each timestamp, you'll see: + - What skills have been completed so far (cumulative history) + - What skill is currently being executed + + {timestamp_context} + + # Your Goal + For EACH timestamp, generate: + 1. **user_prompt**: A natural user request that would lead to the robot performing the current skill + 2. **robot_utterance**: A natural robot response acknowledging the request + + # Guidelines + - Watch the video from start to each timestamp to understand the context + - Ground prompts in what's visible in the video at that time + - Vary interaction types: direct commands, implicit requests, corrections, constraints + - Examples of user prompt styles: + * Direct: "Can you pick up the red brick?" + * Implicit: "I need something red for the tower" + * Negative: "Don't pick up the blue one" + * Constraint: "Make sure to handle it gently" + * Correction: "Actually, move to the other box instead" + - Robot responses should be appropriate: confirmations, clarifications, or error handling + - Ensure continuity across timestamps (don't contradict earlier dialogue) + - Consider world knowledge (dietary preferences, object properties, etc.) + + # Scenario Types: + - **specific_object**: User specifies exact object/action + - **negative_task**: User says what NOT to do + - **situated_correction**: User adjusts based on current state + - **implicit_request**: User implies need without direct command + - **constraint_based**: User adds specific constraints + + # Response Types: + - **confirmation**: Simple "OK, I'll do X" + - **clarification**: "Just to confirm, you want me to..." 
def construct_prompt_image(
    task_description: str,
    skill_history: list[str],
    skill_current: str,
) -> str:
    """
    Build the image-mode text prompt sent to pgen.

    Only the five most recent entries of ``skill_history`` are rendered; when
    older entries were dropped the rendered history is prefixed with ``"... "``.
    An empty history is rendered as a fixed "starting the task" placeholder.

    Args:
        task_description: High-level task description
        skill_history: List of previously completed skills
        skill_current: Current skill to be performed

    Returns:
        ``PGEN_PROMPT_TEMPLATE_IMAGE`` with its three placeholders filled in
    """
    if not skill_history:
        history_text = "None (starting the task)"
    else:
        recent_skills = skill_history[-5:]
        quoted_skills = [f'"{name}"' for name in recent_skills]
        history_text = ", ".join(quoted_skills)
        # Signal that the list was truncated to the most recent entries.
        if len(skill_history) > 5:
            history_text = f"... {history_text}"

    return PGEN_PROMPT_TEMPLATE_IMAGE.format(
        task_description=task_description,
        skill_history=history_text,
        skill_current=skill_current,
    )
# Qwen VLM Interface

class QwenPgen:
    """Qwen VLM wrapper for synthetic dialogue generation."""

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.bfloat16,
        temperature: float = 0.7,
    ):
        """
        Load the Qwen model and its processor.

        Args:
            model_name: Hugging Face model identifier. Names containing
                "qwen3" (case-insensitive) are loaded with
                ``Qwen3VLMoeForConditionalGeneration``; everything else with
                ``Qwen2VLForConditionalGeneration``.
            device: Passed as ``device_map`` when loading and used as the
                target device for processor outputs.
            torch_dtype: Dtype used to load the model weights.
            temperature: Sampling temperature used for generation.
        """
        # Heavy optional dependencies are imported lazily so that importing
        # this module does not require them.
        from qwen_vl_utils import process_vision_info
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        self.console = Console()
        self.device = device
        self.model_name = model_name
        self.temperature = temperature
        self.process_vision_info = process_vision_info

        self.console.print(f"[cyan]Loading Qwen model: {model_name}...[/cyan]")

        # Load model based on name
        if "qwen3" in model_name.lower():
            from transformers import Qwen3VLMoeForConditionalGeneration

            model_cls = Qwen3VLMoeForConditionalGeneration
        else:
            model_cls = Qwen2VLForConditionalGeneration

        self.model = model_cls.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )

        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")

    def call_qwen(
        self,
        images: list[Image.Image | str] | None = None,
        prompt: str = "",
        video: str | Path | None = None,
    ) -> dict[str, str]:
        """
        Call Qwen VLM to generate synthetic dialogue for a single request.

        Args:
            images: List of PIL Images or image paths (for image mode)
            prompt: Text prompt for generation
            video: Path to video file (for video mode)

        Returns:
            Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
        """
        # Use batch method with single item
        results = self.call_qwen_batch(
            batch_images=[images] if images else [None],
            batch_prompts=[prompt],
            batch_videos=[video] if video else [None],
        )
        return results[0]

    def call_qwen_batch(
        self,
        batch_images: list[list[Image.Image | str] | None],
        batch_prompts: list[str],
        batch_videos: list[str | Path | None] | None = None,
    ) -> list[dict[str, str]]:
        """
        Call Qwen VLM to generate synthetic dialogue for a batch of requests.

        For each request a video takes precedence over images when both are
        given. Responses that cannot be parsed are replaced by a default entry
        with empty ``user_prompt`` / ``robot_utterance`` and a warning is printed.

        Args:
            batch_images: List of image lists, one per request (None for video mode)
            batch_prompts: List of text prompts, one per request
            batch_videos: List of video paths, one per request (None for image mode)

        Returns:
            List of dictionaries, each with keys: scenario_type, response_type, user_prompt, robot_utterance

        Raises:
            ValueError: If the three batch lists do not have the same length.
        """
        if batch_videos is None:
            batch_videos = [None] * len(batch_images)

        if len(batch_images) != len(batch_prompts) or len(batch_images) != len(batch_videos):
            raise ValueError(
                f"Batch size mismatch: {len(batch_images)} image lists vs "
                f"{len(batch_prompts)} prompts vs {len(batch_videos)} videos"
            )

        if not batch_images:
            return []

        # Build one single-turn chat per request.
        all_messages = []
        for images, prompt, video in zip(batch_images, batch_prompts, batch_videos):
            content = []

            if video is not None:
                # Video mode
                content.append({"type": "video", "video": str(video), "fps": 1.0})
            elif images is not None:
                # Image mode: paths and PIL images use the same message entry.
                content.extend({"type": "image", "image": img} for img in images)

            content.append({"type": "text", "text": prompt})
            all_messages.append([{"role": "user", "content": content}])

        # Render chat templates and collect the vision inputs of every request.
        texts = []
        all_image_inputs = []
        all_video_inputs = []
        for messages in all_messages:
            texts.append(
                self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            )
            image_inputs, video_inputs = self.process_vision_info(messages)
            all_image_inputs.append(image_inputs)
            all_video_inputs.append(video_inputs)

        # The processor expects a flat list of images/videos across all batch items.
        flat_images = self._flatten_vision_inputs(all_image_inputs)
        flat_videos = self._flatten_vision_inputs(all_video_inputs)

        # NOTE(review): batched generation with a decoder-only model normally
        # requires left padding; confirm the processor's tokenizer padding side
        # before relying on batch sizes > 1.
        inputs = self.processor(
            text=texts,
            images=flat_images if flat_images else None,
            videos=flat_videos if flat_videos else None,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        # Generate
        with torch.no_grad():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=self.temperature,
            )

        # Strip the prompt tokens so only the newly generated text is decoded.
        responses = self.processor.batch_decode(
            [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
            skip_special_tokens=True,
        )

        results = []
        for response in responses:
            try:
                results.append(self._parse_response(response.strip()))
            except Exception as e:  # best-effort: one bad sample must not abort the batch
                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
                # Return empty/default result
                results.append({
                    "scenario_type": "specific_object",
                    "response_type": "confirmation",
                    "user_prompt": "",
                    "robot_utterance": "",
                })

        return results

    @staticmethod
    def _flatten_vision_inputs(per_request_inputs: list) -> list:
        """Concatenate per-request vision inputs into one flat list, skipping ``None`` entries."""
        flat = []
        for entry in per_request_inputs:
            if entry is None:
                continue
            if isinstance(entry, list):
                flat.extend(entry)
            else:
                flat.append(entry)
        return flat

    def _parse_response(self, response: str) -> dict[str, str]:
        """Parse JSON response from model.

        Accepts a bare JSON object, a fenced code block (with or without a
        ``json`` tag), or a JSON object embedded in surrounding text. Missing
        keys are filled with defaults.

        Raises:
            ValueError: If no JSON object can be extracted from ``response``.
        """
        # Extract JSON from response
        if "```json" in response:
            response = response.split("```json")[1].split("```")[0]
        elif "```" in response:
            response = response.split("```")[1].split("```")[0]

        error_message = f"Could not parse response: {response[:200]}..."

        try:
            data = json.loads(response)
        except json.JSONDecodeError as err:
            # Try to find JSON object in response
            match = re.search(r"\{.*\}", response, re.DOTALL)
            if match is None:
                raise ValueError(error_message) from err
            try:
                data = json.loads(match.group())
            except json.JSONDecodeError as inner_err:
                raise ValueError(error_message) from inner_err

        # Valid JSON that is not an object (e.g. a list or a number) has no
        # fields to read; report it as a parse failure like any other.
        if not isinstance(data, dict):
            raise ValueError(error_message)

        return {
            "scenario_type": data.get("scenario_type", "specific_object"),
            "response_type": data.get("response_type", "confirmation"),
            "user_prompt": data.get("user_prompt", ""),
            "robot_utterance": data.get("robot_utterance", ""),
        }
+ + Args: + pgen: Qwen model wrapper + video: Path to episode video file + task_description: High-level task description + timestamps_with_skills: List of dicts with timestamp, skills_so_far, current_skill + + Returns: + List of dictionaries with generated dialogue, one per timestamp + """ + # Use batch method with single episode + results = annotate_episodes_video_batch( + pgen=pgen, + batch_videos=[video], + batch_task_descriptions=[task_description], + batch_timestamps_with_skills=[timestamps_with_skills], + ) + return results[0] + + +def annotate_episodes_video_batch( + pgen: QwenPgen, + batch_videos: list[str | Path], + batch_task_descriptions: list[str], + batch_timestamps_with_skills: list[list[dict]], +) -> list[list[dict[str, Any]]]: + """ + Generate synthetic dialogue for multiple episodes using videos in batch. + + Args: + pgen: Qwen model wrapper + batch_videos: List of paths to episode video files + batch_task_descriptions: List of high-level task descriptions + batch_timestamps_with_skills: List of timestamp lists, one per episode + + Returns: + List of result lists, one per episode (each containing dicts with generated dialogue) + """ + batch_size = len(batch_videos) + if batch_size == 0: + return [] + + # Build messages for each episode + all_messages = [] + for video, task_desc, timestamps_with_skills in zip( + batch_videos, batch_task_descriptions, batch_timestamps_with_skills + ): + prompt = construct_prompt_video(task_desc, timestamps_with_skills) + + content = [ + {"type": "video", "video": str(video), "fps": 1.0}, + {"type": "text", "text": prompt}, + ] + + messages = [{"role": "user", "content": content}] + all_messages.append(messages) + + # Process all episodes through Qwen in batch + all_texts = [] + all_image_inputs = [] + all_video_inputs = [] + + for messages in all_messages: + text = pgen.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_inputs, video_inputs = pgen.process_vision_info(messages) 
def _normalize_video_items(items: Any) -> list[dict[str, Any]]:
    """Keep the dict entries of ``items`` and fill in defaults for missing keys."""
    return [
        {
            "timestamp": item.get("timestamp", 0.0),
            "scenario_type": item.get("scenario_type", "specific_object"),
            "response_type": item.get("response_type", "confirmation"),
            "user_prompt": item.get("user_prompt", ""),
            "robot_utterance": item.get("robot_utterance", ""),
        }
        for item in items
        if isinstance(item, dict)
    ]


def _parse_video_response(response: str, timestamps_with_skills: list[dict]) -> list[dict[str, Any]]:
    """Parse JSON array response from video mode.

    Accepts a bare JSON array, a fenced code block (with or without a ``json``
    tag), a JSON object wrapping the array under ``"annotations"`` or
    ``"results"``, or an array embedded in surrounding text.

    Args:
        response: Raw text generated by the model.
        timestamps_with_skills: The timestamps that were requested; each dict
            must have a ``"timestamp"`` key. Used to build the fallback output.

    Returns:
        One dict per parsed entry with keys ``timestamp``, ``scenario_type``,
        ``response_type``, ``user_prompt`` and ``robot_utterance``. If no JSON
        can be extracted, a warning is printed and one entry with empty
        dialogue is returned per requested timestamp.

    Raises:
        ValueError: If the response is valid JSON but neither an array nor an
            object with an ``"annotations"`` / ``"results"`` key.
    """
    # Extract JSON from response
    if "```json" in response:
        response = response.split("```json")[1].split("```")[0]
    elif "```" in response:
        response = response.split("```")[1].split("```")[0]

    try:
        data = json.loads(response)
    except json.JSONDecodeError:
        # Try to find JSON array in response
        data = None
        match = re.search(r"\[.*\]", response, re.DOTALL)
        if match:
            try:
                data = json.loads(match.group())
            except json.JSONDecodeError:
                # The bracketed span was not JSON either; use the fallback below.
                data = None
        if isinstance(data, list):
            return _normalize_video_items(data)
    else:
        if not isinstance(data, list):
            # If it's a dict with a list inside
            if isinstance(data, dict) and "annotations" in data:
                data = data["annotations"]
            elif isinstance(data, dict) and "results" in data:
                data = data["results"]
            else:
                raise ValueError("Expected JSON array or dict with 'annotations'/'results' key")
        return _normalize_video_items(data)

    # Fallback: return empty results for each timestamp
    print(f"Warning: Could not parse video response: {response[:200]}...")
    return [
        {
            "timestamp": ts["timestamp"],
            "scenario_type": "specific_object",
            "response_type": "confirmation",
            "user_prompt": "",
            "robot_utterance": "",
        }
        for ts in timestamps_with_skills
    ]
+ + Args: + dataset: LeRobot dataset with skill annotations + pgen: Qwen model wrapper + skills_metadata: Loaded skills.json metadata + video_key: Video observation key (e.g., 'observation.images.base') + video_extractor: VideoExtractor instance + console: Rich console for logging + sample_interval_seconds: Sample timestamps at this interval + batch_size: Number of episodes to process in each VLM batch call + + Returns: + Tuple of (tasks_df, task_indices_array, debug_outputs) + """ + coarse_description = skills_metadata.get("coarse_description", "Complete the task") + episodes = skills_metadata.get("episodes", {}) + + # Track unique high-level tasks + high_level_tasks = {} + task_index_counter = 0 + + # Array to store task index for each frame + full_dataset_length = len(dataset) + task_indices = np.zeros(full_dataset_length, dtype=np.int64) + + debug_outputs = [] + timestamps_processed = 0 + + console.print(f"[cyan]Processing {len(episodes)} episodes in VIDEO MODE with batch_size={batch_size}...[/cyan]") + console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s[/cyan]") + + # Convert episodes dict to list for batching + episode_list = list(episodes.items()) + + # Process episodes in batches + for batch_start in tqdm(range(0, len(episode_list), batch_size), desc="Processing episode batches"): + batch_end = min(batch_start + batch_size, len(episode_list)) + batch_episodes = episode_list[batch_start:batch_end] + + # Collect data for this batch + batch_data = [] + extracted_videos = [] + + for episode_key, episode_data in batch_episodes: + episode_idx = int(episode_key) + skills = episode_data.get("skills", []) + description = episode_data.get("description", coarse_description) + + if not skills: + console.print(f"[yellow]Warning: Episode {episode_idx} has no skills[/yellow]") + continue + + # Get video path and extract full episode + extracted_path = None + try: + video_path = dataset.root / dataset.meta.get_video_file_path(episode_idx, video_key) + if 
not video_path.exists(): + console.print(f"[yellow]Warning: Video not found for episode {episode_idx}[/yellow]") + continue + + # Get episode timestamps + ep = dataset.meta.episodes[episode_idx] + episode_start_ts = float(ep[f"videos/{video_key}/from_timestamp"]) + episode_end_ts = float(ep[f"videos/{video_key}/to_timestamp"]) + duration = episode_end_ts - episode_start_ts + + # Extract FULL episode video + extracted_path = video_extractor.extract_episode_video( + video_path, episode_start_ts, episode_end_ts, target_fps=1 + ) + extracted_videos.append(extracted_path) + + except Exception as e: + console.print(f"[yellow]Warning: Failed to extract video for episode {episode_idx}: {e}[/yellow]") + continue + + # Build list of timestamps to sample + timestamps_with_skills = [] + current_time = 0.0 + + while current_time <= duration: + # Find which skill is active at this timestamp + current_skill = None + skills_so_far = [] + + for skill in skills: + if skill["end"] <= current_time: + skills_so_far.append(skill["name"]) + elif skill["start"] <= current_time < skill["end"]: + current_skill = skill["name"] + break + elif current_time >= skill["end"] and skill == skills[-1]: + current_skill = skill["name"] + break + + if current_skill: + timestamps_with_skills.append({ + "timestamp": current_time, + "skills_so_far": skills_so_far.copy(), + "current_skill": current_skill, + }) + + current_time += sample_interval_seconds + + if not timestamps_with_skills: + console.print(f"[yellow]Warning: No valid timestamps for episode {episode_idx}[/yellow]") + continue + + # Store batch item + batch_data.append({ + "episode_idx": episode_idx, + "episode_metadata": ep, + "video_path": extracted_path, + "task_description": description, + "timestamps_with_skills": timestamps_with_skills, + "skills": skills, + }) + + if not batch_data: + continue + + # BATCHED VLM CALL for all episodes in batch + try: + batch_results = annotate_episodes_video_batch( + pgen=pgen, + 
batch_videos=[item["video_path"] for item in batch_data], + batch_task_descriptions=[item["task_description"] for item in batch_data], + batch_timestamps_with_skills=[item["timestamps_with_skills"] for item in batch_data], + ) + + # Process results for each episode in batch + for item, results in zip(batch_data, batch_results): + episode_idx = item["episode_idx"] + ep = item["episode_metadata"] + timestamps_with_skills = item["timestamps_with_skills"] + description = item["task_description"] + + timestamps_processed += len(results) + + # Map results back to timestamps and create task indices + timestamp_to_result = {} + for result in results: + ts = result["timestamp"] + timestamp_to_result[ts] = result + + # Process each sampled timestamp + for ts_info in timestamps_with_skills: + ts = ts_info["timestamp"] + result = timestamp_to_result.get(ts, { + "timestamp": ts, + "scenario_type": "specific_object", + "response_type": "confirmation", + "user_prompt": "", + "robot_utterance": "", + }) + + # Create unique task key + task_key = ( + result["user_prompt"], + result["robot_utterance"], + ts_info["current_skill"], + result["scenario_type"], + result["response_type"], + ) + + # Assign or create task index + if task_key not in high_level_tasks: + high_level_tasks[task_key] = task_index_counter + task_index_counter += 1 + + current_task_idx = high_level_tasks[task_key] + + # Find all frames at this timestamp and assign task_idx + ep_from = ep["dataset_from_index"] + ep_to = ep["dataset_to_index"] + + for frame_idx in range(ep_from, ep_to): + frame = dataset[frame_idx] + frame_ts = frame["timestamp"].item() + + # Assign to closest sampled timestamp + if abs(frame_ts - ts) < sample_interval_seconds / 2: + task_indices[frame_idx] = current_task_idx + + # Save for debugging + debug_outputs.append({ + "episode_id": int(episode_idx), + "timestamp": float(ts), + "skill_current": ts_info["current_skill"], + "skills_so_far": ts_info["skills_so_far"], + "task_description": 
description, + "video_mode": True, + **result, + }) + + finally: + # Clean up extracted videos + for extracted_path in extracted_videos: + if extracted_path and extracted_path.exists(): + extracted_path.unlink() + + console.print(f"[green]✓ Processed {timestamps_processed} timestamps across {len(episodes)} episodes[/green]") + + # Create tasks DataFrame + tasks_data = [] + for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]): + user_prompt, robot_utterance, skill, scenario_type, response_type = task_key + tasks_data.append({ + "task": f"{user_prompt} | {robot_utterance}", + "task_index": task_idx, + "user_prompt": user_prompt, + "robot_utterance": robot_utterance, + "skill": skill, + "scenario_type": scenario_type, + "response_type": response_type, + }) + + tasks_df = pd.DataFrame(tasks_data).set_index("task") + console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]") + + return tasks_df, task_indices, debug_outputs + + +def generate_synthetic_data( + dataset: LeRobotDataset, + pgen: QwenPgen, + skills_metadata: dict, + image_keys: list[str], + sample_interval_seconds: float = 1.0, + console: Console | None = None, + video_mode: bool = False, + video_key: str | None = None, + video_batch_size: int = 1, +) -> tuple[pd.DataFrame, np.ndarray, list[dict]]: + """ + Generate synthetic dialogue data for entire dataset. + + This function processes ALL frames in the dataset, but only calls the VLM + at specified intervals (sample_interval_seconds). Frames between samples + inherit the task_index from the most recent sample. 
+ + Args: + dataset: LeRobot dataset with skill annotations + pgen: Qwen model wrapper + skills_metadata: Loaded skills.json metadata + image_keys: List of image observation keys to use (for image mode) + sample_interval_seconds: Generate dialogue every N seconds (default: 1.0) + console: Rich console for logging + video_mode: If True, use video clips instead of sampled images + video_key: Video observation key for video mode (e.g., 'observation.images.base') + video_batch_size: Number of episodes to process in each VLM batch (video mode only) + + Returns: + Tuple of (tasks_df, task_indices_array, debug_outputs) + - tasks_df: DataFrame with high-level tasks (user_prompt, robot_utterance, etc.) + - task_indices_array: Array of task indices for each frame (full dataset length) + - debug_outputs: List of debug dictionaries (only for sampled frames) + """ + if console is None: + console = Console() + + # Extract metadata + coarse_description = skills_metadata.get("coarse_description", "Complete the task") + episodes = skills_metadata.get("episodes", {}) + + # Track unique high-level tasks + high_level_tasks = {} # (user_prompt, robot_utterance, skill) -> task_index + task_index_counter = 0 # Start at 0 + + # Array to store task index for each frame - MUST match full dataset length + full_dataset_length = len(dataset) + task_indices = np.zeros(full_dataset_length, dtype=np.int64) + + # For debugging - save to JSONL + debug_outputs = [] + + # Initialize video extractor if in video mode + video_extractor = VideoExtractor(console) if video_mode else None + + if video_mode: + if video_key is None: + raise ValueError("video_key must be provided when video_mode=True") + console.print(f"[cyan]Using VIDEO MODE with video key: {video_key}[/cyan]") + console.print(f"[cyan]Video batch size: {video_batch_size} episodes per VLM call[/cyan]") + # In video mode, process episodes in batches with full videos + return _generate_synthetic_data_video_mode( + dataset=dataset, + pgen=pgen, + 
skills_metadata=skills_metadata, + video_key=video_key, + video_extractor=video_extractor, + console=console, + sample_interval_seconds=sample_interval_seconds, + batch_size=video_batch_size, + ) + + # IMAGE MODE (original logic) + # Track sampling + last_sample_timestamp = {} # episode_idx -> last sampled timestamp + last_task_index = {} # episode_idx -> last generated task_index + frames_sampled = 0 + + console.print(f"[cyan]Processing all {full_dataset_length} frames from {dataset.meta.total_episodes} episodes...[/cyan]") + console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s (fps: {dataset.meta.fps})[/cyan]") + + # Process each frame in the FULL dataset + for frame_idx in tqdm(range(full_dataset_length), desc="Generating synthetic dialogue"): + try: + # Get frame data + frame = dataset[frame_idx] + episode_idx = frame["episode_index"].item() + timestamp = frame["timestamp"].item() + + # Get episode skills + episode_key = str(episode_idx) + if episode_key not in episodes: + console.print(f"[yellow]Warning: Episode {episode_idx} not in skills metadata[/yellow]") + continue + + episode_data = episodes[episode_key] + skills = episode_data.get("skills", []) + description = episode_data.get("description", coarse_description) + + # Find current skill + current_skill = get_skill_at_timestamp(skills, timestamp) + if current_skill is None: + console.print(f"[yellow]Warning: No skill found for timestamp {timestamp}[/yellow]") + continue + + # Determine if we should sample this frame + should_sample = False + + # Always sample first frame of an episode + if episode_idx not in last_sample_timestamp: + should_sample = True + last_sample_timestamp[episode_idx] = timestamp + else: + # Sample if enough time has passed + time_since_last = timestamp - last_sample_timestamp[episode_idx] + if time_since_last >= sample_interval_seconds: + should_sample = True + last_sample_timestamp[episode_idx] = timestamp + + # If not sampling, reuse last task index for this episode 
+ if not should_sample: + if episode_idx in last_task_index: + task_indices[frame_idx] = last_task_index[episode_idx] + continue + + # Sample this frame - generate synthetic dialogue + frames_sampled += 1 + + # Build skill history (all skills before current timestamp) + skill_history = [] + for skill in skills: + if skill["end"] <= timestamp: + skill_history.append(skill["name"]) + + # Load images + images = [] + for img_key in image_keys: + if img_key in frame: + # Frame images are tensors (C, H, W) in [0, 1] + img_tensor = frame[img_key] + if len(img_tensor.shape) == 4: # (T, C, H, W) + img_tensor = img_tensor[-1] # Take last frame + + # Convert to PIL Image + img_array = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + img_pil = Image.fromarray(img_array) + images.append(img_pil) + + if not images: + console.print(f"[yellow]Warning: No images found for frame {frame_idx}[/yellow]") + continue + + # Generate synthetic dialogue + result = annotate_sample_image( + pgen=pgen, + images=images, + task_description=description, + skill_history=skill_history, + skill_current=current_skill, + ) + + # Create unique task key + task_key = ( + result["user_prompt"], + result["robot_utterance"], + current_skill, + result["scenario_type"], + result["response_type"], + ) + + # Assign or create task index + if task_key not in high_level_tasks: + high_level_tasks[task_key] = task_index_counter + task_index_counter += 1 + + current_task_idx = high_level_tasks[task_key] + task_indices[frame_idx] = current_task_idx + last_task_index[episode_idx] = current_task_idx + + # Save for debugging + debug_outputs.append({ + "episode_id": int(episode_idx), + "frame_index": frame_idx, + "timestamp": float(timestamp), + "skill_current": current_skill, + "skill_history": skill_history, + "task_description": description, + "sampled": True, + **result, + }) + + except Exception as e: + console.print(f"[red]Error processing frame {frame_idx}: {e}[/red]") + continue + + 
console.print(f"[green]✓ Sampled {frames_sampled} frames out of {full_dataset_length} total ({frames_sampled/full_dataset_length*100:.1f}%)[/green]") + + # Create tasks DataFrame + tasks_data = [] + for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]): + user_prompt, robot_utterance, skill, scenario_type, response_type = task_key + tasks_data.append({ + "task": f"{user_prompt} | {robot_utterance}", + "task_index": task_idx, + "user_prompt": user_prompt, + "robot_utterance": robot_utterance, + "skill": skill, + "scenario_type": scenario_type, + "response_type": response_type, + }) + + tasks_df = pd.DataFrame(tasks_data).set_index("task") + + console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]") + + return tasks_df, task_indices, debug_outputs + + +def save_high_level_tasks( + tasks_df: pd.DataFrame, + dataset_root: Path, + console: Console | None = None, +) -> None: + """Save high-level tasks to tasks_high_level.parquet.""" + if console is None: + console = Console() + + output_path = dataset_root / "meta" / "tasks_high_level.parquet" + output_path.parent.mkdir(parents=True, exist_ok=True) + + tasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy") + console.print(f"[green]✓ Saved high-level tasks to {output_path}[/green]") + + +def save_debug_outputs( + debug_outputs: list[dict], + dataset_root: Path, + console: Console | None = None, +) -> None: + """Save debug outputs to JSONL file.""" + if console is None: + console = Console() + + output_path = dataset_root / "meta" / "syn_annotations.jsonl" + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + for item in debug_outputs: + f.write(json.dumps(item) + "\n") + + console.print(f"[green]✓ Saved debug annotations to {output_path}[/green]") + + +# main entry point + +def main(): + """Main entry point for synthetic data generation.""" + parser = argparse.ArgumentParser( + description="Generate 
synthetic dialogue data for hierarchical robot policies", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=textwrap.dedent("""\ + Examples: + # Generate synthetic data for a dataset (image mode) + python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\ + --model Qwen/Qwen2-VL-7B-Instruct \\ + --output-dir ./output + + # Use video mode with batching (passes full episode videos) + python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\ + --model Qwen/Qwen2-VL-7B-Instruct \\ + --video-mode \\ + --video-key observation.images.base \\ + --video-batch-size 4 + + # Use Qwen3 model with custom parameters + python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\ + --model Qwen/Qwen3-VL-30B-A3B-Instruct \\ + --temperature 0.8 \\ + --batch-size 1 + """), + ) + + # Data source + data_group = parser.add_mutually_exclusive_group(required=True) + data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset") + data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID") + + # Model configuration + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2-VL-7B-Instruct", + help="VLM model to use (default: Qwen/Qwen2-VL-7B-Instruct)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run model on (default: cuda)", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float16", "float32"], + help="Model dtype (default: bfloat16)", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature (default: 0.7)", + ) + + # Processing options + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size for processing (default: 1) [currently unused]", + ) + parser.add_argument( + "--num-image-views-per-sample", + type=int, + default=1, + help="Number of camera views to use per sample (default: 1)", + ) + parser.add_argument( + 
"--sample-interval", + type=float, + default=1.0, + help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. " + "Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.", + ) + parser.add_argument( + "--video-mode", + action="store_true", + help="Use video input instead of sampled image frames. Passes entire skill video clips to the model.", + ) + parser.add_argument( + "--video-key", + type=str, + default=None, + help="Video observation key for video mode (e.g., 'observation.images.base'). " + "If not specified, uses the first available video key.", + ) + parser.add_argument( + "--video-batch-size", + type=int, + default=1, + help="Number of episodes to process in each VLM batch call in video mode (default: 1)", + ) + + # Output options + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory for modified dataset", + ) + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Push modified dataset to HuggingFace Hub", + ) + # add image key + parser.add_argument( + "--image-key", + type=str, + default=None, + help="Image observation key to use for image mode (default: None)", + ) + + args = parser.parse_args() + console = Console() + + # Load dataset + console.print("[cyan]Loading dataset...[/cyan]") + if args.data_dir: + dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir) + dataset_root = Path(args.data_dir) + else: + dataset = LeRobotDataset(repo_id=args.repo_id) + dataset_root = dataset.root + + console.print(f"[green]✓ Loaded dataset with {len(dataset)} frames[/green]") + + # Load skills metadata + console.print("[cyan]Loading skills metadata...[/cyan]") + skills_metadata = load_skills_metadata(dataset_root) + if skills_metadata is None: + console.print("[red]Error: No skills.json found. 
Run annotate.py first![/red]") + return + + console.print(f"[green]✓ Loaded skills for {len(skills_metadata.get('episodes', {}))} episodes[/green]") + + # Initialize model + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + torch_dtype = dtype_map[args.dtype] + + console.print(f"[cyan]Initializing {args.model}...[/cyan]") + pgen = QwenPgen( + model_name=args.model, + device=args.device, + torch_dtype=torch_dtype, + temperature=args.temperature, + ) + + # Get image keys (for image mode) + if args.image_key: + image_keys = [args.image_key] + else: + image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample] + if not args.video_mode: + console.print(f"[cyan]Using image keys: {image_keys}[/cyan]") + + # Determine video key for video mode + video_key = None + if args.video_mode: + if args.video_key: + # Use explicitly provided video key + video_key = args.video_key + if video_key not in dataset.meta.video_keys: + console.print(f"[red]Error: Video key '{video_key}' not found in dataset.[/red]") + console.print(f"[yellow]Available video keys: {', '.join(dataset.meta.video_keys)}[/yellow]") + return + elif dataset.meta.video_keys: + # Use first available video key + video_key = dataset.meta.video_keys[0] + else: + console.print("[red]Error: No video keys found in dataset. 
Cannot use video mode.[/red]") + return + console.print(f"[cyan]Using video key for video mode: {video_key}[/cyan]") + + # Generate synthetic data + tasks_df, task_indices, debug_outputs = generate_synthetic_data( + dataset=dataset, + pgen=pgen, + skills_metadata=skills_metadata, + image_keys=image_keys, + sample_interval_seconds=args.sample_interval, + console=console, + video_mode=args.video_mode, + video_key=video_key, + video_batch_size=args.video_batch_size, + ) + + # Save high-level tasks + save_high_level_tasks(tasks_df, dataset_root, console) + save_debug_outputs(debug_outputs, dataset_root, console) + + # Add task_index_high_level feature to dataset + console.print("[cyan]Adding task_index_high_level feature to dataset...[/cyan]") + + # Determine output directory + if args.output_dir: + output_dir = Path(args.output_dir) + repo_id = f"{dataset.repo_id}_with_high_level_tasks" + else: + output_dir = None + repo_id = f"{dataset.repo_id}_with_high_level_tasks" + + # Add feature using dataset_tools + feature_info = { + "dtype": "int64", + "shape": (1,), + "names": None, + } + new_dataset = add_features( + dataset=dataset, + features={ + "task_index_high_level": (task_indices, feature_info), + }, + output_dir=output_dir, + repo_id=repo_id, + ) + + # copy high level tsk parquet to new output directory + import shutil + shutil.copy(dataset_root / "meta" / "tasks_high_level.parquet", output_dir / "meta" / "tasks_high_level.parquet") + shutil.copy(dataset_root / "meta" / "syn_annotations.jsonl", output_dir / "meta" / "syn_annotations.jsonl") + + console.print(f"[bold green]✓ Successfully added task_index_high_level feature![/bold green]") + console.print(f" New dataset saved to: {new_dataset.root}") + console.print(f" Total high-level tasks: {len(tasks_df)}") + + # Push to hub if requested + if args.push_to_hub: + if args.data_dir: + console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]") + else: + console.print("[cyan]Pushing to 
HuggingFace Hub...[/cyan]") + try: + new_dataset.push_to_hub() + console.print(f"[green]✓ Pushed to {repo_id}[/green]") + except Exception as e: + console.print(f"[red]Push failed: {e}[/red]") + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py b/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py new file mode 100644 index 000000000..600953ce6 --- /dev/null +++ b/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py @@ -0,0 +1,23 @@ +import torch +from huggingface_hub import HfApi + +import lerobot +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata + +dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1") + +dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=0, + batch_size=2, + shuffle=True, +) + +batch = next(iter(dataloader)) +print(batch.keys()) +print(batch['task_index_high_level'].shape) +print(batch['task_index_high_level']) +print(batch['user_prompt'][0]) +print(batch['robot_utterance'][0]) +print(batch['task'][0]) +breakpoint() \ No newline at end of file diff --git a/src/lerobot/policies/pi05_full/annotate/run_pgen.sh b/src/lerobot/policies/pi05_full/annotate/run_pgen.sh new file mode 100644 index 000000000..570a67799 --- /dev/null +++ b/src/lerobot/policies/pi05_full/annotate/run_pgen.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Example script to run synthetic data generation with Qwen VLM +# This generates user prompts and robot utterances for hierarchical policy training + +# Configuration +REPO_ID="jadechoghari/collect-data" +MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct" +# or: MODEL="Qwen/Qwen2-VL-7B-Instruct" + + +OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen" +BATCH_SIZE=32 +TEMPERATURE=0.9 +SAMPLE_INTERVAL=5.0 # generate dialogue every 1 second (all episodes processed) + +# run synthetic data generation (all episodes processed) +python examples/dataset/annotate_pgen.py \ + --repo-id 
"$REPO_ID" \ + --model "$MODEL" \ + --output-dir "$OUTPUT_DIR" \ + --temperature "$TEMPERATURE" \ + --batch-size "$BATCH_SIZE" \ + --sample-interval "$SAMPLE_INTERVAL" \ + --image-key observation.images.base \ + --num-image-views-per-sample 1 + +# for faster testing, increase sample interval: +# --sample-interval 5.0 # Samples every 5 seconds (much faster) + +# to push to hub after generation: +# add --push-to-hub flag + +# efficient batch processing: 4 episodes at once +# python examples/dataset/annotate_pgen.py \ +# --repo-id "$REPO_ID" \ +# --model "$MODEL" \ +# --output-dir "$OUTPUT_DIR" \ +# --video-mode \ +# --video-key observation.images.up \ +# --video-batch-size "$BATCH_SIZE" \ +# --sample-interval 1.0 diff --git a/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py b/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py new file mode 100644 index 000000000..da296bd2d --- /dev/null +++ b/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py @@ -0,0 +1,1258 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Automatic Skill Annotation for LeRobot Datasets. + +This script performs automatic subtask/skill labeling for ANY LeRobot dataset using +Vision-Language Models (VLMs). It segments each robot demonstration into short atomic +skills (1-3 seconds each) and updates the dataset's task field. + +The pipeline: +1. 
Loads a LeRobot dataset (local or from HuggingFace Hub) +2. For each episode, extracts video frames +3. Uses a VLM to identify skill boundaries and labels +4. Updates the dataset's task metadata with skill annotations + +Supported VLMs (modular design allows easy extension): +- Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct" +- Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct" + +Usage: +```bash +python examples/dataset/annotate.py \ + --repo-id your-username/your-dataset \ + --video-key observation.images.base \ + --model Qwen/Qwen2-VL-7B-Instruct \ + --push-to-hub +``` + +Or with a local dataset: +```bash +python examples/dataset/annotate.py \ + --data-dir /path/to/local/dataset \ + --video-key observation.images.base +``` +After running, you can access the skill for any frame via: +```python +dataset = LeRobotDataset(repo_id="your/dataset") +item = dataset[100] +task_idx = item["task_index"] +skill_name = dataset.meta.tasks.iloc[task_idx.item()].name +``` +""" + +import argparse +import json +import re +import subprocess +import tempfile +import textwrap +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import cv2 +import torch +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +# Skill Annotation Data Structures + + +class Skill: + """Represents a single atomic skill/subtask in a demonstration.""" + + def __init__(self, name: str, start: float, end: float): + self.name = name + self.start = start # Start timestamp in seconds + self.end = end # End timestamp in seconds + + def to_dict(self) -> dict: + return {"name": self.name, "start": self.start, "end": self.end} + + @classmethod + def from_dict(cls, data: dict) -> "Skill": + return cls(name=data["name"], start=data["start"], end=data["end"]) + + def __repr__(self) -> str: + return f"Skill(name='{self.name}', start={self.start:.2f}, end={self.end:.2f})" + + +class 
EpisodeSkills: + """Container for all skills in an episode.""" + + def __init__(self, episode_index: int, description: str, skills: list[Skill]): + self.episode_index = episode_index + self.description = description + self.skills = skills + + def to_dict(self) -> dict: + return { + "episode_index": self.episode_index, + "description": self.description, + "skills": [s.to_dict() for s in self.skills], + } + + +# VLM Interface (Abstract Base Class for Modularity) + + +class BaseVLM(ABC): + """ + Abstract base class for Vision-Language Models. + + To add a new VLM: + 1. Create a subclass of BaseVLM + 2. Implement the `__init__`, `segment_skills`, and `segment_skills_batch` methods + 3. Register it in the VLM_REGISTRY dictionary + """ + + @abstractmethod + def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): + """Initialize the VLM with model name, device, and dtype.""" + pass + + @abstractmethod + def segment_skills( + self, video_path: Path, episode_duration: float, coarse_goal: str | None = None + ) -> list[Skill]: + """ + Segment a video into atomic skills. + + Args: + video_path: Path to the video file + episode_duration: Total duration of the episode in seconds + coarse_goal: Optional high-level task description + + Returns: + List of Skill objects representing atomic manipulation skills + """ + pass + + @abstractmethod + def segment_skills_batch( + self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None + ) -> list[list[Skill]]: + """ + Segment multiple videos into atomic skills in a single batch. 
+ + Args: + video_paths: List of paths to video files + episode_durations: List of episode durations in seconds + coarse_goal: Optional high-level task description + + Returns: + List of skill lists, one for each video + """ + pass + + +def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str: + """Create the prompt for skill segmentation.""" + goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else "" + + return textwrap.dedent(f"""\ + # Role + You are a Robotics Vision System specializing in temporal action segmentation for robot manipulation demonstrations. + + # Task + {goal_context}Segment this robot demonstration video into short atomic manipulation skills. Each skill should: + - Last approximately 1-3 seconds + - Describe a clear, single action (e.g., "pick up object", "move arm left", "release gripper") + - Have precise start and end timestamps + + # Requirements + 1. **Atomic Actions**: Each skill should be a single, indivisible action + 2. **Complete Coverage**: Skills must cover the entire video duration with no gaps + 3. **Boundary Consistency**: The end of one skill equals the start of the next + 4. **Natural Language**: Use clear, descriptive names for each skill + 5. **Timestamps**: Use seconds (float) for all timestamps + + + + # Output Format + After your analysis, output ONLY valid JSON with this exact structure: + + ```json + {{ + "skills": [ + {{"name": "skill description", "start": 0.0, "end": 1.5}}, + {{"name": "another skill", "start": 1.5, "end": 3.2}} + ] + }} + ``` + + The first skill must start at 0.0 and the last skill must end at the video duration. 
+ """) + + +# Qwen2-VL Implementation + + +class Qwen2VL(BaseVLM): + """Qwen2-VL model for skill segmentation.""" + + def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16): + from qwen_vl_utils import process_vision_info + from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + self.console = Console() + self.device = device + self.model_name = model_name + self.process_vision_info = process_vision_info + + self.console.print(f"[cyan]Loading Qwen2-VL model: {model_name}...[/cyan]") + + self.model = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]") + + def segment_skills( + self, video_path: Path, episode_duration: float, coarse_goal: str | None = None + ) -> list[Skill]: + """Segment video into skills using Qwen2-VL.""" + prompt = create_skill_segmentation_prompt(coarse_goal) + duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}" + + messages = [ + {"role": "system", "content": [{"type": "text", "text": prompt}]}, + { + "role": "user", + "content": [ + {"type": "video", "video": str(video_path), "fps": 1.0}, + { + "type": "text", + "text": f"Video duration: {duration_str} (~{episode_duration:.1f}s). 
Segment into atomic skills.", + }, + ], + }, + ] + + text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_inputs, video_inputs = self.process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ).to(self.device) + + with torch.no_grad(): + generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7) + + response = self.processor.batch_decode( + [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)], + skip_special_tokens=True, + )[0].strip() + + return self._parse_skills_response(response) + + def segment_skills_batch( + self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None + ) -> list[list[Skill]]: + """Segment multiple videos into skills using Qwen2-VL in a batch.""" + prompt = create_skill_segmentation_prompt(coarse_goal) + + # Create messages for each video + all_messages = [] + for video_path, duration in zip(video_paths, episode_durations): + duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}" + messages = [ + {"role": "system", "content": [{"type": "text", "text": prompt}]}, + { + "role": "user", + "content": [ + {"type": "video", "video": str(video_path), "fps": 1.0}, + { + "type": "text", + "text": f"Video duration: {duration_str} (~{duration:.1f}s). 
def _parse_skills_response(self, response: str) -> list[Skill]:
    """Parse the raw VLM reply into Skill objects.

    The reply may be bare JSON or JSON wrapped in a markdown code fence, and
    the payload may be either an object with a ``"skills"`` list or a bare
    list of skill dicts. If the reply as a whole is not valid JSON, the first
    ``{...}`` span found in it is tried instead.

    Args:
        response: Decoded model output.

    Returns:
        One Skill per entry of the skills list (possibly empty).

    Raises:
        ValueError: If no JSON payload containing a list of skills can be
            recovered from the reply.
    """
    # Strip a markdown code fence if the model wrapped its answer in one.
    if "```json" in response:
        response = response.split("```json")[1].split("```")[0]
    elif "```" in response:
        response = response.split("```")[1].split("```")[0]

    error_message = f"Could not parse skills from response: {response[:200]}..."

    try:
        data = json.loads(response)
    except json.JSONDecodeError:
        # The model added prose around the JSON: look for an embedded object.
        match = re.search(r"\{.*\}", response, re.DOTALL)
        if match is None:
            raise ValueError(error_message) from None
        try:
            data = json.loads(match.group())
        except json.JSONDecodeError as exc:
            raise ValueError(error_message) from exc
        # An embedded object without a "skills" key yields no skills.
        skills_data = data.get("skills", [])
    else:
        # json.loads may return a list or scalar, which has no .get(); only
        # dicts are unwrapped, anything else is validated below as-is.
        skills_data = data.get("skills", data) if isinstance(data, dict) else data

    if not isinstance(skills_data, list):
        raise ValueError(error_message)

    return [Skill.from_dict(s) for s in skills_data]
def segment_skills_batch(
    self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
) -> list[list[Skill]]:
    """Segment multiple videos into skills using Qwen3-VL in a batch.

    Args:
        video_paths: One clip per episode.
        episode_durations: Clip lengths in seconds, aligned with ``video_paths``.
        coarse_goal: Optional task description forwarded to the prompt builder.

    Returns:
        One list of skills per input video, in input order. A reply that
        cannot be parsed yields an empty list for that video (a warning is
        printed) instead of failing the whole batch.

    Raises:
        ValueError: If ``video_paths`` and ``episode_durations`` differ in length.
    """
    # zip() would silently drop the unmatched tail and misalign the results.
    if len(video_paths) != len(episode_durations):
        raise ValueError(
            f"video_paths ({len(video_paths)}) and episode_durations "
            f"({len(episode_durations)}) must have the same length"
        )
    if not video_paths:
        return []

    prompt = create_skill_segmentation_prompt(coarse_goal)

    all_texts = []
    all_image_inputs = []
    all_video_inputs = []
    for video_path, duration in zip(video_paths, episode_durations):
        duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
        messages = [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path), "fps": 1.0},
                    {
                        "type": "text",
                        "text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
                    },
                ],
            },
        ]
        all_texts.append(
            self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        )
        image_inputs, video_inputs = self.process_vision_info(messages)
        all_image_inputs.extend(image_inputs or [])
        all_video_inputs.extend(video_inputs or [])

    # Generation appends tokens after the last position of every row, so in a
    # padded batch the pad tokens must sit on the left; otherwise shorter
    # prompts are continued from padding. Restore the caller's setting after.
    tokenizer = self.processor.tokenizer
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = self.processor(
            text=all_texts,
            images=all_image_inputs if all_image_inputs else None,
            videos=all_video_inputs if all_video_inputs else None,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
    finally:
        tokenizer.padding_side = original_padding_side

    with torch.no_grad():
        generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)

    # generate() returns prompt + completion; every padded prompt row has the
    # same length, so slicing by it leaves only the new tokens.
    responses = self.processor.batch_decode(
        [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
        skip_special_tokens=True,
    )

    all_skills = []
    for response in responses:
        try:
            all_skills.append(self._parse_skills_response(response.strip()))
        except Exception as e:
            # Best effort: one bad reply must not discard the rest of the batch.
            self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
            all_skills.append([])

    return all_skills
class VideoExtractor:
    """Utilities for extracting and processing video segments from LeRobot datasets."""

    def __init__(self, console: Console | None = None):
        self.console = console or Console()

    def extract_episode_video(
        self,
        video_path: Path,
        start_timestamp: float,
        end_timestamp: float,
        target_fps: int = 1,
    ) -> Path:
        """
        Extract a specific episode segment from a concatenated video file.

        The segment is re-encoded with ffmpeg (libx264, audio dropped) into a
        new temporary ``.mp4``. The caller owns the returned file and must
        delete it; on any failure the temporary file is removed here.

        Args:
            video_path: Path to the source video file
            start_timestamp: Start time in seconds
            end_timestamp: End time in seconds
            target_fps: Target frames per second for output

        Returns:
            Path to the extracted temporary video file

        Raises:
            RuntimeError: If ffmpeg is missing, exits with an error, or
                produces a missing/near-empty (< 1 KiB) output file.
        """
        # Only the unique name is needed; ffmpeg (-y) overwrites the placeholder.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tmp_path = Path(tmp_file.name)
        tmp_file.close()

        duration = end_timestamp - start_timestamp

        self.console.print(
            f"[cyan]Extracting: {start_timestamp:.1f}s - {end_timestamp:.1f}s ({duration:.1f}s)[/cyan]"
        )

        cmd = [
            "ffmpeg",
            "-i",
            str(video_path),
            "-ss",
            str(start_timestamp),
            "-t",
            str(duration),
            "-r",
            str(target_fps),
            "-c:v",
            "libx264",
            "-preset",
            "ultrafast",
            "-crf",
            "23",
            "-an",
            "-y",
            str(tmp_path),
        ]

        try:
            try:
                subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
            except subprocess.CalledProcessError as e:
                raise RuntimeError(f"FFmpeg failed: {e}") from e
            except FileNotFoundError as e:
                raise RuntimeError("FFmpeg not found. Please install ffmpeg.") from e

            if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
                raise RuntimeError("Video extraction produced invalid file")
        except BaseException:
            # The placeholder was created with delete=False, so every failure
            # path has to remove it or it accumulates in the temp directory.
            tmp_path.unlink(missing_ok=True)
            raise

        return tmp_path

    def get_video_duration(self, video_path: Path) -> float:
        """Get duration of a video file in seconds.

        Computed as frame count divided by the frame rate reported by OpenCV;
        a reported frame rate of 0 is replaced by 30.
        """
        cap = cv2.VideoCapture(str(video_path))
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        return frame_count / fps
def annotate_dataset(
    self,
    dataset: LeRobotDataset,
    video_key: str,
    episodes: list[int] | None = None,
    skip_existing: bool = False,
) -> dict[int, EpisodeSkills]:
    """
    Annotate all episodes in a dataset with skill labels using batched processing.

    Each batch is first sent to the VLM in one call; if that fails, the
    episodes of the batch are retried one at a time. Episodes that cannot be
    annotated are reported on the console and left out of the result.

    Args:
        dataset: LeRobot dataset to annotate
        video_key: Key for video observations (e.g., "observation.images.base")
        episodes: Specific episode indices to annotate (None = all)
        skip_existing: Skip episodes already present in the dataset's
            ``meta/skills.json`` (as read by ``load_skill_annotations``)

    Returns:
        Dictionary mapping episode index to EpisodeSkills
    """
    # `episodes or ...` would turn an explicit empty selection into "all".
    if episodes is None:
        episode_indices = list(range(dataset.meta.total_episodes))
    else:
        episode_indices = list(episodes)

    if skip_existing:
        try:
            existing = load_skill_annotations(dataset.root) or {}
        except (json.JSONDecodeError, OSError) as e:
            # An unreadable file cannot tell us what to skip; annotate everything.
            self.console.print(f"[yellow]Warning: Could not read existing annotations: {e}[/yellow]")
            existing = {}
        # skills.json keys its episodes by the stringified episode index.
        already_annotated = set(existing.get("episodes", {}))
        remaining = [ep_idx for ep_idx in episode_indices if str(ep_idx) not in already_annotated]
        skipped = len(episode_indices) - len(remaining)
        if skipped:
            self.console.print(f"[cyan]Skipping {skipped} already annotated episode(s)[/cyan]")
        episode_indices = remaining

    annotations: dict[int, EpisodeSkills] = {}

    # Get coarse task description if available
    coarse_goal = self._get_coarse_goal(dataset)

    print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...")

    total_batches = (len(episode_indices) + self.batch_size - 1) // self.batch_size

    # Process episodes in batches
    for batch_start in range(0, len(episode_indices), self.batch_size):
        batch_episodes = episode_indices[batch_start : batch_start + self.batch_size]

        print(
            f"Processing batch {batch_start // self.batch_size + 1}/{total_batches} "
            f"(episodes {batch_episodes[0]} to {batch_episodes[-1]})..."
        )

        try:
            batch_annotations = self._annotate_episodes_batch(
                dataset, batch_episodes, video_key, coarse_goal
            )

            for ep_idx, skills in batch_annotations.items():
                annotations[ep_idx] = EpisodeSkills(
                    episode_index=ep_idx,
                    description=coarse_goal,
                    skills=skills,
                )
                self.console.print(f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]")
        except Exception as batch_error:
            self.console.print(
                f"[red]✗ Batch failed: {batch_error}. Falling back to single-episode processing...[/red]"
            )
            # Fallback: process episodes one by one so a single bad episode
            # does not cost the whole batch.
            for ep_idx in batch_episodes:
                try:
                    skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
                    annotations[ep_idx] = EpisodeSkills(
                        episode_index=ep_idx,
                        description=coarse_goal,
                        skills=skills,
                    )
                    self.console.print(
                        f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
                    )
                except Exception as episode_error:
                    self.console.print(f"[red]✗ Episode {ep_idx} failed: {episode_error}[/red]")

    return annotations
def _annotate_episodes_batch(
    self,
    dataset: LeRobotDataset,
    episode_indices: list[int],
    video_key: str,
    coarse_goal: str,
) -> dict[int, list[Skill]]:
    """Extract the clips for several episodes and segment them in one VLM call.

    Episodes whose video is missing or whose clip cannot be extracted are
    reported as warnings and left out of the result. All temporary clips are
    deleted before returning, whether or not the VLM call succeeds.

    Returns:
        Mapping from episode index to the skills found for that episode;
        empty if no clip could be extracted.
    """
    from_key = f"videos/{video_key}/from_timestamp"
    to_key = f"videos/{video_key}/to_timestamp"

    clip_paths = []
    clip_durations = []
    usable_episodes = []

    for ep_idx in episode_indices:
        try:
            source_video = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
            if not source_video.exists():
                self.console.print(f"[yellow]Warning: Video not found for episode {ep_idx}[/yellow]")
                continue

            # Episode boundaries inside the (possibly concatenated) video file.
            episode_meta = dataset.meta.episodes[ep_idx]
            start_ts = float(episode_meta[from_key])
            end_ts = float(episode_meta[to_key])

            clip = self.video_extractor.extract_episode_video(source_video, start_ts, end_ts, target_fps=1)
        except Exception as e:
            self.console.print(f"[yellow]Warning: Failed to extract video for episode {ep_idx}: {e}[/yellow]")
            continue

        clip_paths.append(clip)
        clip_durations.append(end_ts - start_ts)
        usable_episodes.append(ep_idx)

    if not clip_paths:
        return {}

    try:
        per_clip_skills = self.vlm.segment_skills_batch(clip_paths, clip_durations, coarse_goal)
        # Results come back in clip order, which matches usable_episodes.
        return dict(zip(usable_episodes, per_clip_skills))
    finally:
        # The clips are temporary files owned by this call.
        for clip in clip_paths:
            if clip.exists():
                clip.unlink()
def get_skill_for_timestamp(skills: list[Skill], timestamp: float) -> Skill | None:
    """
    Find which skill applies to a given timestamp.

    A skill covers ``timestamp`` when ``skill.start <= timestamp < skill.end``.
    Timestamps that no skill covers are still assigned a skill so every frame
    gets a label:

    * in a gap between skills, or past the end of every skill, the skill that
      started most recently before the timestamp is used;
    * before the start of every skill, the earliest-starting skill is used.

    Args:
        skills: List of skills with start/end times
        timestamp: Frame timestamp in seconds

    Returns:
        The skill assigned to this timestamp, or None if ``skills`` is empty
    """
    if not skills:
        return None

    for skill in skills:
        if skill.start <= timestamp < skill.end:
            return skill

    # Not covered: fall back to the closest skill in time rather than blindly
    # to the final one, which would mislabel gaps early in the episode.
    already_started = [skill for skill in skills if skill.start <= timestamp]
    if already_started:
        return max(already_started, key=lambda skill: skill.start)
    return min(skills, key=lambda skill: skill.start)
def update_frame_task_indices(
    dataset: LeRobotDataset,
    annotations: dict[int, EpisodeSkills],
    skill_to_task_idx: dict[str, int],
) -> None:
    """
    Update the task_index for each frame based on skill annotations.

    This reads the data parquet files, updates task_index based on which
    skill covers each frame's timestamp, and writes back to disk. A file is
    rewritten only if at least one of its frames changed; a missing data file
    is reported and skipped.

    Args:
        dataset: The LeRobot dataset to update
        annotations: Dictionary of episode skills
        skill_to_task_idx: Mapping from skill name to task_index
    """
    import pandas as pd

    console = Console()

    # Group episodes by their data file (chunk_index, file_index) so each
    # parquet file is read and written at most once.
    episodes_by_file: dict[tuple[int, int], list[int]] = {}
    for ep_idx in annotations:
        ep = dataset.meta.episodes[ep_idx]
        key = (ep["data/chunk_index"], ep["data/file_index"])
        episodes_by_file.setdefault(key, []).append(ep_idx)

    for (chunk_idx, file_idx), episode_indices in episodes_by_file.items():
        data_path = dataset.root / dataset.meta.data_path.format(
            chunk_index=chunk_idx, file_index=file_idx
        )

        if not data_path.exists():
            console.print(f"[yellow]Warning: Data file not found: {data_path}[/yellow]")
            continue

        df = pd.read_parquet(data_path)
        updated_count = 0

        for ep_idx in episode_indices:
            skills = annotations[ep_idx].skills

            # Rows of this episode: global frame index in [from, to).
            ep = dataset.meta.episodes[ep_idx]
            episode_mask = (df["index"] >= ep["dataset_from_index"]) & (df["index"] < ep["dataset_to_index"])

            # Only the timestamp is needed per frame, so iterate that column
            # instead of materialising a full row for every frame.
            for row_label, timestamp in df.loc[episode_mask, "timestamp"].items():
                skill = get_skill_for_timestamp(skills, timestamp)
                if not skill or skill.name not in skill_to_task_idx:
                    continue

                new_task_idx = skill_to_task_idx[skill.name]
                if df.at[row_label, "task_index"] != new_task_idx:
                    df.at[row_label, "task_index"] = new_task_idx
                    updated_count += 1

        # Write back if any changes were made
        if updated_count > 0:
            df.to_parquet(data_path, engine="pyarrow", compression="snappy", index=False)
            console.print(
                f"[green]✓ Updated {updated_count} frame task_indices in {data_path.name}[/green]"
            )
def load_skill_annotations(dataset_root: Path) -> dict | None:
    """Return the parsed ``meta/skills.json`` of a dataset, or None if the file is absent."""
    annotations_file = dataset_root / "meta" / "skills.json"
    if not annotations_file.exists():
        return None

    with annotations_file.open() as handle:
        return json.load(handle)
"""Main entry point for the skill annotation script.""" + parser = argparse.ArgumentParser( + description="Automatic skill annotation for LeRobot datasets using VLMs (with batched processing)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=textwrap.dedent("""\ + Examples: + # Annotate a HuggingFace Hub dataset + python annotate.py --repo-id user/dataset --video-key observation.images.base + + # Annotate a local dataset with custom batch size + python annotate.py --data-dir /path/to/dataset --video-key observation.images.base --batch-size 16 + + # Use a specific model + python annotate.py --repo-id user/dataset --video-key observation.images.base \\ + --model Qwen/Qwen2-VL-7B-Instruct + + # Push annotated dataset to Hub + python annotate.py --repo-id user/dataset --video-key observation.images.base --push-to-hub + """), + ) + + # Data source (mutually exclusive) + data_group = parser.add_mutually_exclusive_group(required=True) + data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset") + data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID") + + # Required arguments + parser.add_argument( + "--video-key", + type=str, + required=True, + help="Video observation key (e.g., 'observation.images.base')", + ) + + # Model configuration + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2-VL-7B-Instruct", + help="VLM model to use for skill segmentation (default: Qwen/Qwen2-VL-7B-Instruct)", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run model on (default: cuda)", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + choices=["bfloat16", "float16", "float32"], + help="Model dtype (default: bfloat16)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=8, + help="Number of episodes to process in each batch (default: 8)", + ) + + # Episode selection + parser.add_argument( + "--episodes", + 
type=int, + nargs="+", + help="Specific episode indices to annotate (default: all)", + ) + parser.add_argument( + "--skip-existing", + action="store_true", + help="Skip episodes that already have annotations", + ) + + # Output options + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Push annotated dataset to HuggingFace Hub", + ) + parser.add_argument( + "--output-path", + type=str, + help="Custom output path for annotations JSON", + ) + + args = parser.parse_args() + console = Console() + + # Validate arguments + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + torch_dtype = dtype_map[args.dtype] + + # Load dataset + console.print("[cyan]Loading dataset...[/cyan]") + if args.data_dir: + dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir, download_videos=False) + else: + dataset = LeRobotDataset(repo_id=args.repo_id, download_videos=True) + + console.print(f"[green]✓ Loaded dataset with {dataset.meta.total_episodes} episodes[/green]") + + # Validate video key + if args.video_key not in dataset.meta.video_keys: + available = ", ".join(dataset.meta.video_keys) + console.print(f"[red]Error: Video key '{args.video_key}' not found. 
Available: {available}[/red]") + return + + # Initialize VLM + console.print(f"[cyan]Initializing VLM: {args.model}...[/cyan]") + vlm = get_vlm(args.model, args.device, torch_dtype) + + # Create annotator and run annotation + annotator = SkillAnnotator(vlm=vlm, console=console, batch_size=args.batch_size) + console.print(f"[cyan]Processing with batch size: {args.batch_size}[/cyan]") + annotations = annotator.annotate_dataset( + dataset=dataset, + video_key=args.video_key, + episodes=args.episodes, + skip_existing=args.skip_existing, + ) + + # Save annotations + output_path = Path(args.output_path) if args.output_path else None + save_skill_annotations(dataset, annotations, output_path) + + # Summary + total_skills = sum(len(ann.skills) for ann in annotations.values()) + console.print(f"\n[bold green]✓ Annotation complete![/bold green]") + console.print(f" Episodes annotated: {len(annotations)}") + console.print(f" Total skills identified: {total_skills}") + + # Push to hub if requested + if args.push_to_hub: + if args.data_dir: + console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]") + else: + console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]") + try: + dataset.push_to_hub(push_videos=False) + console.print(f"[green]✓ Pushed to {args.repo_id}[/green]") + except Exception as e: + console.print(f"[red]Push failed: {e}[/red]") + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/policies/pi05_full/configuration_pi05.py b/src/lerobot/policies/pi05_full/configuration_pi05.py new file mode 100644 index 000000000..b96e6d196 --- /dev/null +++ b/src/lerobot/policies/pi05_full/configuration_pi05.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
@PreTrainedConfig.register_subclass("pi05")
@dataclass
class PI05Config(PreTrainedConfig):
    """Configuration of the π₀.₅ policy: model variants, flow-matching
    sampling, feature padding, and optimizer/scheduler presets."""

    paligemma_variant: str = "gemma_2b"
    action_expert_variant: str = "gemma_300m"
    dtype: str = "float32"  # Options: "bfloat16", "float32"

    n_obs_steps: int = 1
    chunk_size: int = 50  # Number of action steps to predict, in openpi called "action_horizon"
    n_action_steps: int = 50  # Number of action steps to execute

    # Shorter state and action vectors will be padded to these dimensions
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Flow matching parameters: see openpi `PI0Pytorch`
    num_inference_steps: int = 10
    time_sampling_beta_alpha: float = 1.5
    time_sampling_beta_beta: float = 1.0
    time_sampling_scale: float = 0.999
    time_sampling_offset: float = 0.001
    min_period: float = 4e-3
    max_period: float = 4.0

    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None

    image_resolution: tuple[int, int] = (
        DEFAULT_IMAGE_SIZE,
        DEFAULT_IMAGE_SIZE,
    )  # see openpi `preprocessing_pytorch.py`

    # Add empty images. Used to add empty cameras when no image features are present.
    empty_cameras: int = 0

    # Declared once; a second identical declaration further down was redundant.
    tokenizer_max_length: int = 200  # see openpi `__post_init__`

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for state
            "ACTION": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for action
        }
    )

    # Training settings
    gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization
    compile_model: bool = False  # Whether to use torch.compile for model optimization
    compile_mode: str = "max-autotune"  # Torch compile mode
    device: str | None = None  # Device to use for the model (None = auto-detect)

    # Finetuning settings
    freeze_vision_encoder: bool = False  # Freeze only the vision encoder
    train_expert_only: bool = False  # Freeze entire VLM, train only action expert and projections

    # Optimizer settings: see openpi `AdamW`
    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.01
    optimizer_grad_clip_norm: float = 1.0

    # Scheduler settings: see openpi `CosineDecaySchedule`
    # Note: These will auto-scale if --steps < scheduler_decay_steps
    # For example, --steps=3000 will scale warmup to 100 and decay to 3000
    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    def __post_init__(self):
        """Validate the step counts, model variants and dtype.

        Raises:
            ValueError: If ``n_action_steps`` exceeds ``chunk_size`` or a
                variant/dtype is not one of the supported values.
        """
        super().__post_init__()

        # Validate configuration
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
            )

        if self.paligemma_variant not in ("gemma_300m", "gemma_2b"):
            raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")

        if self.action_expert_variant not in ("gemma_300m", "gemma_2b"):
            raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")

        if self.dtype not in ("bfloat16", "float32"):
            raise ValueError(f"Invalid dtype: {self.dtype}")

    def validate_features(self) -> None:
        """Validate and set up input/output features.

        Registers ``empty_cameras`` placeholder image inputs and adds default
        state/action features (padded to the max dims) when they are absent.
        """
        for i in range(self.empty_cameras):
            key = OBS_IMAGES + f".empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, *self.image_resolution),  # Use configured image resolution
            )
            self.input_features[key] = empty_camera

        if OBS_STATE not in self.input_features:
            state_feature = PolicyFeature(
                type=FeatureType.STATE,
                shape=(self.max_state_dim,),  # Padded to max_state_dim
            )
            self.input_features[OBS_STATE] = state_feature

        if ACTION not in self.output_features:
            action_feature = PolicyFeature(
                type=FeatureType.ACTION,
                shape=(self.max_action_dim,),  # Padded to max_action_dim
            )
            self.output_features[ACTION] = action_feature

    def get_optimizer_preset(self) -> AdamWConfig:
        """Build the AdamW configuration from the ``optimizer_*`` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self):
        """Build the cosine-decay-with-warmup schedule from the ``scheduler_*`` fields."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        # One index per predicted action step of the chunk.
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -0,0 +1,1272 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import logging +import math +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal, TypedDict + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from typing_extensions import Unpack + +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + modeling_gemma = None + GemmaForCausalLM = None + PaliGemmaForConditionalGeneration = None + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OPENPI_ATTENTION_MASK_VALUE, 
+) + + +class ActionSelectKwargs(TypedDict, total=False): + inference_delay: int | None + prev_chunk_left_over: Tensor | None + execution_horizon: int | None + + +def get_safe_dtype(target_dtype, device_type): + """Get a safe dtype for the given device type.""" + if device_type == "mps" and target_dtype == torch.float64: + return torch.float32 + if device_type == "cpu": + # CPU doesn't support bfloat16, use float32 instead + if target_dtype == torch.bfloat16: + return torch.float32 + if target_dtype == torch.float64: + return torch.float64 + return target_dtype + + +def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) + time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + +def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + return dist.sample((bsize,)) + + +def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) + """Copied from big_vision. 
+ + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + return att_2d_masks & pad_2d_masks + + +def pad_vector(vector, new_dim): + """Pad the last dimension of a vector to new_dim with zeros. + + Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] >= new_dim: + return vector + return F.pad(vector, (0, new_dim - vector.shape[-1])) + + +def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. 
+ + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) + + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + return padded_images + + +# 
Define the complete layer computation function for gradient checkpointing +def compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert +): + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = 
paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + outputs_embeds.append(out_emb) + start_pos = end_pos + return outputs_embeds + + +class GemmaConfig: # see openpi `gemma.py: Config` + """Configuration for Gemma model variants.""" + + def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim): + self.width = width + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + +def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config` + """Returns config for specified gemma variant.""" + if variant == "gemma_300m": + return GemmaConfig( + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + elif variant == "gemma_2b": + return GemmaConfig( + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + else: + raise ValueError(f"Unknown variant: {variant}") + + +class 
PaliGemmaWithExpertModel( + nn.Module +): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi + """PaliGemma model with action expert for PI05.""" + + def __init__( + self, + vlm_config, + action_expert_config, + use_adarms=None, + precision: Literal["bfloat16", "float32"] = "bfloat16", + image_size: int = DEFAULT_IMAGE_SIZE, + freeze_vision_encoder: bool = False, + train_expert_only: bool = False, + ): + if use_adarms is None: + use_adarms = [False, False] + super().__init__() + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only + + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.image_size = image_size + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.torch_dtype = "float32" + + action_expert_config_hf = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + 
num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + torch_dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, + ) + + self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_for_selected_params(precision) + self._set_requires_grad() + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def _set_requires_grad(self): + if self.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + for param in self.paligemma.vision_tower.parameters(): + param.requires_grad = False + if self.train_expert_only: + self.paligemma.eval() + for param in self.paligemma.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + if self.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + if self.train_expert_only: + self.paligemma.eval() + + def embed_image(self, image: torch.Tensor): + return self.paligemma.model.get_image_features(image) + + def 
embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[1] if adarms_cond is not None else None, + ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + + # Check if gradient checkpointing is enabled for any of the models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = 
torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + + # final norm + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + return [prefix_output, suffix_output], prefix_past_key_values + + +class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI05 PyTorch model.""" + + def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): + super().__init__() + self.config = config + self.rtc_processor = rtc_processor + + paligemma_config = get_gemma_config(config.paligemma_variant) + action_expert_config = get_gemma_config(config.action_expert_variant) + + if config.image_resolution[0] != config.image_resolution[1]: + raise ValueError( + f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}" + ) + + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, True], + precision=config.dtype, + 
image_size=config.image_resolution[0], + freeze_vision_encoder=config.freeze_vision_encoder, + train_expert_only=config.train_expert_only, + ) + + self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) + + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing for PI05Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + 
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + logging.info("Disabled gradient checkpointing for PI05Pytorch model") + + def _rtc_enabled(self): + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + + def sample_noise(self, shape, device): + return torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta( + self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device + ) + time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, tokens, masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_masks += [0] * 
num_img_embs + + # Process language tokens + def lang_embed_func(tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, tokens) + embs.append(lang_emb) + pad_masks.append(masks) + + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, noisy_actions, timestep): + """Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Embed timestep using sine-cosine positional encoding + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.action_in_proj.out_features, + min_period=self.config.min_period, + max_period=self.config.max_period, + device=timestep.device, + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) + action_time_emb = action_emb + adarms_cond = time_emb + + embs.append(action_time_emb) + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * 
(self.config.chunk_size - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, adarms_cond + + def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: + """Do a full training forward pass and compute the loss.""" + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) + + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + ) + + 
suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + + return F.mse_loss(u_t, v_t, reduction="none") + + @torch.no_grad() # see openpi `sample_actions` (slightly adapted) + def sample_actions( + self, + images, + img_masks, + tokens, + masks, + noise=None, + num_steps=None, + **kwargs: Unpack[ActionSelectKwargs], + ) -> Tensor: + """Do a full inference forward and compute the action.""" + if num_steps is None: + num_steps = self.config.num_inference_steps + + bsize = tokens.shape[0] + device = tokens.device + + if noise is None: + # Sample noise with padded dimension as expected by action_in_proj + actions_shape = ( + bsize, + self.config.chunk_size, + self.config.max_action_dim, + ) # Use config max_action_dim for internal processing + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks_4d, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, + ) + + dt = -1.0 / num_steps + + x_t = noise + for step in range(num_steps): + time = 1.0 + step * dt + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) + + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): + return self.denoise_step( + prefix_pad_masks=prefix_pad_masks, + 
past_key_values=past_key_values, + x_t=input_x_t, + timestep=current_timestep, + ) + + if self._rtc_enabled(): + inference_delay = kwargs.get("inference_delay") + prev_chunk_left_over = kwargs.get("prev_chunk_left_over") + execution_horizon = kwargs.get("execution_horizon") + + v_t = self.rtc_processor.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev_chunk_left_over, + inference_delay=inference_delay, + time=time, + original_denoise_step_partial=denoise_step_partial_call, + execution_horizon=execution_horizon, + ) + else: + v_t = denoise_step_partial_call(x_t) + + x_t = x_t + dt * v_t + + if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): + self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) + + return x_t + + def denoise_step( + self, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks_4d, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + suffix_out = 
@classmethod
def from_pretrained(
    cls: builtins.type[T],
    pretrained_name_or_path: str | Path,
    *,
    config: PreTrainedConfig | None = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = True,
    **kwargs,
) -> T:
    """Build the policy and load remapped OpenPI weights from ``model.safetensors``.

    The state dict is passed through ``_fix_pytorch_state_dict_keys`` and every
    key that does not already start with ``"model."`` gets that prefix before
    being loaded.

    Loading is best-effort: if the weights file cannot be resolved or read,
    or if remapping/loading raises (including a ``strict`` mismatch), the
    error is printed and the model is still returned.

    Args:
        pretrained_name_or_path: Local directory or hub repo id.
        config: Optional ready-made config; loaded from
            ``pretrained_name_or_path`` when omitted.
        force_download, resume_download, proxies, token, cache_dir,
        local_files_only, revision: Download options. They are forwarded both
            to the config loader and to the weight-file lookup.
        strict: Passed to ``load_state_dict``.
        **kwargs: Forwarded to the config loader and to the constructor.

    Returns:
        The instantiated policy.

    Raises:
        ValueError: If ``pretrained_name_or_path`` is ``None``.
    """
    print(
        "The PI05 model is a direct port of the OpenPI implementation. \n"
        "This implementation follows the original OpenPI structure for compatibility. \n"
        "Original implementation: https://github.com/Physical-Intelligence/openpi"
    )
    if pretrained_name_or_path is None:
        raise ValueError("pretrained_name_or_path is required")

    # Use provided config if available, otherwise load it from the checkpoint.
    if config is None:
        config = PreTrainedConfig.from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            **kwargs,
        )

    # Initialize the model without weights; they are loaded manually below.
    model = cls(config, **kwargs)

    try:
        print(f"Loading model from: {pretrained_name_or_path}")
        try:
            from transformers.utils import cached_file

            # The download options are named parameters of this method, so they
            # never show up in **kwargs; pass them explicitly or they are lost.
            resolved_file = cached_file(
                pretrained_name_or_path,
                "model.safetensors",
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                revision=revision,
                local_files_only=local_files_only,
            )
            from safetensors.torch import load_file

            original_state_dict = load_file(resolved_file)
            print("✓ Loaded state dict from model.safetensors")
        except Exception as e:
            # Deliberate best-effort: hand back the un-loaded model.
            print(f"Could not load state dict from remote files: {e}")
            print("Returning model without loading pretrained weights")
            return model

        # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
        fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

        # Then add the "model." prefix to every key that lacks it.
        remapped_state_dict = {}
        remap_count = 0
        for key, value in fixed_state_dict.items():
            if key.startswith("model."):
                remapped_state_dict[key] = value
                continue
            new_key = f"model.{key}"
            remapped_state_dict[new_key] = value
            remap_count += 1
            if remap_count <= 10:  # Only print first 10 to avoid spam
                print(f"Remapped: {key} -> {new_key}")

        if remap_count > 0:
            print(f"Remapped {remap_count} state dict keys")

        missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)

        # Report at most five keys per category, then a count of the rest.
        for label, keys in (("Missing", missing_keys), ("Unexpected", unexpected_keys)):
            if not keys:
                continue
            print(f"{label} keys when loading state dict: {len(keys)} keys")
            for key in keys[:5]:
                print(f" - {key}")
            if len(keys) > 5:
                print(f" ... and {len(keys) - 5} more")

        if not missing_keys and not unexpected_keys:
            print("All keys loaded successfully!")

    except Exception as e:
        print(f"Warning: Could not remap state dict keys: {e}")

    return model
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    """Convert the batch's camera images into model inputs plus validity masks.

    Each configured image feature found in ``batch`` is moved to the model's
    device, cast to float32, resized with padding when its spatial size
    differs from ``config.image_resolution``, and rescaled with
    ``img * 2 - 1`` (i.e. [0, 1] -> [-1, 1]). Inputs whose dimension 1 has
    size 3 are treated as channels-first and are returned channels-first.

    Configured image features that are absent from ``batch`` are replaced by
    placeholder images filled with -1 and an all-False mask.

    Args:
        batch: Mapping from feature name to tensor.

    Returns:
        ``(images, img_masks)``: present cameras first (in config order),
        followed by one placeholder per missing camera. Masks are boolean
        tensors of length batch-size.

    Raises:
        ValueError: If none of the configured image features is in ``batch``.
    """
    # Get device from model parameters
    device = next(self.parameters()).device

    present_img_keys = [key for key in self.config.image_features if key in batch]
    missing_img_keys = [key for key in self.config.image_features if key not in batch]

    if not present_img_keys:
        raise ValueError(
            f"All image features are missing from the batch. At least one expected. "
            f"(batch: {batch.keys()}) (image_features: {self.config.image_features})"
        )

    # Normalise to a tuple: `torch.Size` is a tuple subclass, so comparing it
    # against a list-valued config would always report a mismatch.
    target_resolution = tuple(self.config.image_resolution)

    images = []
    img_masks = []

    for key in present_img_keys:
        # `.to` is a no-op when device and dtype already match.
        img = batch[key].to(device=device, dtype=torch.float32)

        # from openpi preprocess_observation_pytorch: accept [B, C, H, W] and [B, H, W, C]
        is_channels_first = img.shape[1] == 3
        if is_channels_first:
            # Work in [B, H, W, C] so the spatial dims sit at positions 1 and 2.
            img = img.permute(0, 2, 3, 1)

        if tuple(img.shape[1:3]) != target_resolution:
            img = resize_with_pad_torch(img, *target_resolution)

        # Normalize from [0,1] to [-1,1] as expected by siglip
        img = img * 2.0 - 1.0

        if is_channels_first:
            img = img.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        images.append(img)
        # Real images are valid for every sample in the batch.
        img_masks.append(torch.ones(img.shape[0], dtype=torch.bool, device=device))

    # Missing cameras become -1-filled images (the normalised value of a black
    # pixel) with an all-False mask.
    for _ in missing_img_keys:
        images.append(torch.full_like(images[-1], -1.0))
        img_masks.append(torch.zeros_like(img_masks[-1]))

    return images, img_masks
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
    """Compute the training loss for one batch.

    Args:
        batch: Training batch; must contain the configured image features,
            the language token / attention-mask entries and the action entry.
        reduction: ``"none"`` returns one loss value per sample (mean over
            the time and action dimensions). Any other value returns the
            scalar mean over all elements.

    Returns:
        ``(loss, loss_dict)`` where ``loss_dict`` holds ``"loss_per_dim"``
        (mean over the first two dimensions, as a Python list) and ``"loss"``
        (scalar mean as a Python float).
    """
    # Gather model inputs (PI05 takes no separate state input here).
    images, img_masks = self._preprocess_images(batch)
    tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
    masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
    padded_actions = self.prepare_action(batch)

    elementwise = self.model.forward(images, img_masks, tokens, masks, padded_actions)

    # Drop the padded action dimensions before any averaging.
    action_dim = self.config.output_features[ACTION].shape[0]
    elementwise = elementwise[:, :, :action_dim]

    loss_dict = {
        "loss_per_dim": elementwise.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
    }

    if reduction != "none":
        # Default: scalar mean loss.
        scalar_loss = elementwise.mean()
        loss_dict["loss"] = scalar_loss.item()
        return scalar_loss, loss_dict

    # Per-sample losses of shape (B,), averaged over time and action dims.
    per_sample_loss = elementwise.mean(dim=(1, 2))
    loss_dict["loss"] = per_sample_loss.mean().item()
    return per_sample_loss, loss_dict
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
    """Fold the discretized robot state into the task prompt before tokenization.

    For every sample the state is padded to ``max_state_dim``, binned with
    256 left bin edges spanning [-1, 1), and written into a prompt of the
    form ``"Task: <task>, State: <bins>;\\nAction: "`` that replaces the
    entry stored under ``task_key`` in the complementary data.

    The state is expected to already be normalized to [-1, 1], so this step
    must run after the normalizer step.
    """

    # Width the state vector is padded to before discretization.
    max_state_dim: int = 32
    # Key under which the task text is read from / written to complementary data.
    task_key: str = "task"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` whose task entry is the full prompt.

        Raises:
            ValueError: If the observation has no state, or no task is found
                in the complementary data.
        """
        transition = transition.copy()

        state = (transition.get(TransitionKey.OBSERVATION) or {}).get(OBS_STATE)
        if state is None:
            raise ValueError("State is required for PI05")

        # `transition.copy()` is shallow: copy the nested dict as well so the
        # caller's complementary data is not overwritten with the prompt.
        complementary_data = dict(transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
        tasks = complementary_data.get(self.task_key)
        if tasks is None:
            raise ValueError("No task found in complementary data")
        if isinstance(tasks, str):
            # A bare string would otherwise be iterated character by character.
            tasks = [tasks]

        # TODO: check if this necessary
        state = deepcopy(state)

        # Prepare state (pad to max_state_dim)
        state = pad_vector(state, self.max_state_dim)

        # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
        # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
        state_np = state.cpu().numpy()
        discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        full_prompts = []
        for i, task in enumerate(tasks):
            cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
            state_str = " ".join(map(str, discretized_states[i]))
            full_prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")

        complementary_data[self.task_key] = full_prompts
        transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
        return transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        This step does not alter the feature definitions.
        """
        return features
+ """ + + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep + # because the tokenizer step expects normalized state in [-1, 1] range for discretization + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + )