diff --git a/examples/dataset/annotate.py b/examples/dataset/annotate.py
index 1ecb49f82..895145def 100644
--- a/examples/dataset/annotate.py
+++ b/examples/dataset/annotate.py
@@ -30,7 +30,6 @@ The pipeline:
 Supported VLMs (modular design allows easy extension):
 - Qwen2-VL (default): "Qwen/Qwen2-VL-7B-Instruct"
 - Qwen3-VL: "Qwen/Qwen3-VL-30B-A3B-Instruct"
-- SmolVLM: "HuggingFaceTB/SmolVLM-Instruct"
 
 Usage:
 ```bash
@@ -52,7 +51,7 @@ After running, you can access the skill for any frame via:
 dataset = LeRobotDataset(repo_id="your/dataset")
 item = dataset[100]
 task_idx = item["task_index"]
-skill_name = dataset.meta.tasks.iloc[task_idx].name
+skill_name = dataset.meta.tasks.iloc[task_idx.item()].name
 ```
 """
 
@@ -125,7 +124,7 @@ class BaseVLM(ABC):
 
     To add a new VLM:
     1. Create a subclass of BaseVLM
-    2. Implement the `__init__` and `segment_skills` methods
+    2. Implement the `__init__`, `segment_skills`, and `segment_skills_batch` methods
    3. Register it in the VLM_REGISTRY dictionary
     """
 
@@ -151,6 +150,23 @@ class BaseVLM(ABC):
         """
         pass
 
+    @abstractmethod
+    def segment_skills_batch(
+        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
+    ) -> list[list[Skill]]:
+        """
+        Segment multiple videos into atomic skills in a single batch.
+
+        Args:
+            video_paths: List of paths to video files
+            episode_durations: List of episode durations in seconds
+            coarse_goal: Optional high-level task description
+
+        Returns:
+            List of skill lists, one for each video
+        """
+        pass
+
 
 def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
     """Create the prompt for skill segmentation."""
@@ -258,6 +274,71 @@ class Qwen2VL(BaseVLM):
 
         return self._parse_skills_response(response)
 
+    def segment_skills_batch(
+        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
+    ) -> list[list[Skill]]:
+        """Segment multiple videos into skills using Qwen2-VL in a batch."""
+        prompt = create_skill_segmentation_prompt(coarse_goal)
+
+        # Create messages for each video
+        all_messages = []
+        for video_path, duration in zip(video_paths, episode_durations):
+            duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": prompt}]},
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video", "video": str(video_path), "fps": 1.0},
+                        {
+                            "type": "text",
+                            "text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
+                        },
+                    ],
+                },
+            ]
+            all_messages.append(messages)
+
+        # Process all videos in batch
+        all_texts = []
+        all_image_inputs = []
+        all_video_inputs = []
+
+        for messages in all_messages:
+            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            image_inputs, video_inputs = self.process_vision_info(messages)
+            all_texts.append(text)
+            all_image_inputs.extend(image_inputs or [])
+            all_video_inputs.extend(video_inputs or [])
+
+        inputs = self.processor(
+            text=all_texts,
+            images=all_image_inputs if all_image_inputs else None,
+            videos=all_video_inputs if all_video_inputs else None,
+            padding=True,
+            return_tensors="pt",
+        ).to(self.device)
+
+        with torch.no_grad():
+            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
+
+        responses = self.processor.batch_decode(
+            [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
+            skip_special_tokens=True,
+        )
+
+        # Parse each response
+        all_skills = []
+        for response in responses:
+            try:
+                skills = self._parse_skills_response(response.strip())
+                all_skills.append(skills)
+            except Exception as e:
+                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
+                all_skills.append([])
+
+        return all_skills
+
     def _parse_skills_response(self, response: str) -> list[Skill]:
         """Parse the VLM response into Skill objects."""
         # Extract JSON from response
@@ -349,6 +430,71 @@ class Qwen3VL(BaseVLM):
 
         return self._parse_skills_response(response)
 
+    def segment_skills_batch(
+        self, video_paths: list[Path], episode_durations: list[float], coarse_goal: str | None = None
+    ) -> list[list[Skill]]:
+        """Segment multiple videos into skills using Qwen3-VL in a batch."""
+        prompt = create_skill_segmentation_prompt(coarse_goal)
+
+        # Create messages for each video
+        all_messages = []
+        for video_path, duration in zip(video_paths, episode_durations):
+            duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": prompt}]},
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video", "video": str(video_path), "fps": 1.0},
+                        {
+                            "type": "text",
+                            "text": f"Video duration: {duration_str} (~{duration:.1f}s). Segment into atomic skills.",
+                        },
+                    ],
+                },
+            ]
+            all_messages.append(messages)
+
+        # Process all videos in batch
+        all_texts = []
+        all_image_inputs = []
+        all_video_inputs = []
+
+        for messages in all_messages:
+            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            image_inputs, video_inputs = self.process_vision_info(messages)
+            all_texts.append(text)
+            all_image_inputs.extend(image_inputs or [])
+            all_video_inputs.extend(video_inputs or [])
+
+        inputs = self.processor(
+            text=all_texts,
+            images=all_image_inputs if all_image_inputs else None,
+            videos=all_video_inputs if all_video_inputs else None,
+            padding=True,
+            return_tensors="pt",
+        ).to(self.device)
+
+        with torch.no_grad():
+            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
+
+        responses = self.processor.batch_decode(
+            [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
+            skip_special_tokens=True,
+        )
+
+        # Parse each response
+        all_skills = []
+        for response in responses:
+            try:
+                skills = self._parse_skills_response(response.strip())
+                all_skills.append(skills)
+            except Exception as e:
+                self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]")
+                all_skills.append([])
+
+        return all_skills
+
     def _parse_skills_response(self, response: str) -> list[Skill]:
         """Parse the VLM response into Skill objects."""
         if "```json" in response:
@@ -371,137 +517,6 @@ class Qwen3VL(BaseVLM):
         raise ValueError(f"Could not parse skills from response: {response[:200]}...")
 
 
-# =============================================================================
-# SmolVLM Implementation
-# =============================================================================
-
-
-class SmolVLM(BaseVLM):
-    """SmolVLM model for skill segmentation (lighter weight alternative)."""
-
-    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
-        from transformers import AutoModelForVision2Seq, AutoProcessor
-
-        self.console = Console()
-        self.device = device
-        self.model_name = model_name
-
-        self.console.print(f"[cyan]Loading SmolVLM model: {model_name}...[/cyan]")
-
-        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
-        self.model = AutoModelForVision2Seq.from_pretrained(
-            model_name,
-            torch_dtype=torch_dtype,
-            # _attn_implementation="flash_attention_2" if device == "cuda" else "eager",
-        ).to(device)
-
-        self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
-
-    def segment_skills(
-        self, video_path: Path, episode_duration: float, coarse_goal: str | None = None
-    ) -> list[Skill]:
-        """Segment video into skills using SmolVLM with frame sampling."""
-        import PIL.Image
-
-        # SmolVLM works with images, so we sample frames from the video
-        frames = self._extract_frames(video_path, target_fps=1)
-
-        if not frames:
-            raise ValueError(f"Could not extract frames from {video_path}")
-
-        prompt = create_skill_segmentation_prompt(coarse_goal)
-        duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
-
-        # Sample frames (up to 8 frames to avoid context overflow)
-        frame_indices = self._select_frame_indices(len(frames), max_frames=8)
-
-        # Convert frames to PIL images
-        pil_images = [
-            PIL.Image.fromarray(cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB))
-            for idx in frame_indices
-        ]
-
-        # Create message content with image placeholders
-        content = [{"type": "text", "text": prompt}]
-
-        # Add image placeholders (one for each frame)
-        for _ in frame_indices:
-            content.append({"type": "image"})
-
-        content.append(
-            {
-                "type": "text",
-                "text": f"These are {len(frame_indices)} sampled frames from a {duration_str} video. Segment into atomic skills.",
-            }
-        )
-
-        messages = [{"role": "user", "content": content}]
-
-        # Apply chat template to get the prompt
-        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
-
-        # Process inputs with both text and images
-        inputs = self.processor(text=prompt, images=pil_images, return_tensors="pt")
-        inputs = inputs.to(self.device)
-
-        with torch.no_grad():
-            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
-
-        response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
-        return self._parse_skills_response(response, episode_duration)
-
-    def _extract_frames(self, video_path: Path, target_fps: int = 1) -> list:
-        """Extract frames from video at target FPS."""
-        cap = cv2.VideoCapture(str(video_path))
-        frames = []
-        fps = cap.get(cv2.CAP_PROP_FPS) or 30
-        frame_interval = int(fps / target_fps)
-
-        frame_count = 0
-        while True:
-            ret, frame = cap.read()
-            if not ret:
-                break
-            if frame_count % frame_interval == 0:
-                frames.append(frame)
-            frame_count += 1
-
-        cap.release()
-        return frames
-
-    def _select_frame_indices(self, total_frames: int, max_frames: int = 8) -> list[int]:
-        """Select evenly spaced frame indices."""
-        if total_frames <= max_frames:
-            return list(range(total_frames))
-        step = total_frames / max_frames
-        return [int(i * step) for i in range(max_frames)]
-
-    def _parse_skills_response(self, response: str, episode_duration: float) -> list[Skill]:
-        """Parse the VLM response into Skill objects."""
-        if "```json" in response:
-            response = response.split("```json")[1].split("```")[0]
-        elif "```" in response:
-            response = response.split("```")[1].split("```")[0]
-
-        try:
-            data = json.loads(response)
-            skills_data = data.get("skills", data)
-            breakpoint()
-            if isinstance(skills_data, list):
-                return [Skill.from_dict(s) for s in skills_data]
-        except json.JSONDecodeError:
-            match = re.search(r"\{.*\}", response, re.DOTALL)
-            if match:
-                data = json.loads(match.group())
-                skills_data = data.get("skills", [])
-                return [Skill.from_dict(s) for s in skills_data]
-
-        # Fallback: create a single skill covering the whole episode
-        self.console.print("[yellow]Warning: Could not parse skills, creating single skill[/yellow]")
-        return [Skill(name="perform manipulation", start=0.0, end=episode_duration)]
-
-
 # =============================================================================
 # VLM Registry - Add new VLMs here
 # =============================================================================
@@ -513,10 +528,6 @@ VLM_REGISTRY: dict[str, type[BaseVLM]] = {
     "Qwen/Qwen2-VL-72B-Instruct": Qwen2VL,
     # Qwen3-VL variants (MoE)
     "Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
-    # SmolVLM variants
-    "HuggingFaceTB/SmolVLM-Instruct": SmolVLM,
-    "HuggingFaceTB/SmolVLM-256M-Instruct": SmolVLM,
-    "HuggingFaceTB/SmolVLM-500M-Instruct": SmolVLM,
 }
 
 
@@ -545,8 +556,6 @@ def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = to
         return Qwen3VL(model_name, device, torch_dtype)
     elif "qwen2" in model_lower or "qwen-vl" in model_lower:
         return Qwen2VL(model_name, device, torch_dtype)
-    elif "smolvlm" in model_lower:
-        return SmolVLM(model_name, device, torch_dtype)
 
     raise ValueError(
         f"Unknown model: {model_name}. "
" @@ -660,10 +669,12 @@ class SkillAnnotator: vlm: BaseVLM, video_extractor: VideoExtractor | None = None, console: Console | None = None, + batch_size: int = 8, ): self.vlm = vlm self.console = console or Console() self.video_extractor = video_extractor or VideoExtractor(self.console) + self.batch_size = batch_size def annotate_dataset( self, @@ -673,7 +684,7 @@ class SkillAnnotator: skip_existing: bool = False, ) -> dict[int, EpisodeSkills]: """ - Annotate all episodes in a dataset with skill labels. + Annotate all episodes in a dataset with skill labels using batched processing. Args: dataset: LeRobot dataset to annotate @@ -690,32 +701,45 @@ class SkillAnnotator: # Get coarse task description if available coarse_goal = self._get_coarse_goal(dataset) - # with Progress( - # SpinnerColumn(), - # TextColumn("[progress.description]{task.description}"), - # console=self.console, - # ) as progress: - # task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices)) - print(f"Annotating {len(episode_indices)} episodes...") + print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...") - for ep_idx in episode_indices: - # progress.update(task, description=f"Processing episode {ep_idx}...") - print(f"Processing episode {ep_idx}...") + # Process episodes in batches + for batch_start in range(0, len(episode_indices), self.batch_size): + batch_end = min(batch_start + self.batch_size, len(episode_indices)) + batch_episodes = episode_indices[batch_start:batch_end] + + print(f"Processing batch {batch_start//self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1)//self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})...") try: - skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal) - annotations[ep_idx] = EpisodeSkills( - episode_index=ep_idx, - description=coarse_goal, - skills=skills, - ) - self.console.print( - f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]" + batch_annotations = self._annotate_episodes_batch( + dataset, batch_episodes, video_key, coarse_goal ) + + for ep_idx, skills in batch_annotations.items(): + annotations[ep_idx] = EpisodeSkills( + episode_index=ep_idx, + description=coarse_goal, + skills=skills, + ) + self.console.print( + f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]" + ) except Exception as e: - self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]") - - # progress.advance(task) + self.console.print(f"[red]✗ Batch failed: {e}. Falling back to single-episode processing...[/red]") + # Fallback: process episodes one by one + for ep_idx in batch_episodes: + try: + skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal) + annotations[ep_idx] = EpisodeSkills( + episode_index=ep_idx, + description=coarse_goal, + skills=skills, + ) + self.console.print( + f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]" + ) + except Exception as e: + self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]") return annotations @@ -730,6 +754,67 @@ class SkillAnnotator: return "Perform the demonstrated manipulation task." 
+    def _annotate_episodes_batch(
+        self,
+        dataset: LeRobotDataset,
+        episode_indices: list[int],
+        video_key: str,
+        coarse_goal: str,
+    ) -> dict[int, list[Skill]]:
+        """Annotate multiple episodes with skill labels in a batch."""
+        # Extract all videos for this batch
+        extracted_paths = []
+        durations = []
+        valid_episode_indices = []
+
+        for ep_idx in episode_indices:
+            try:
+                # Get video path and timestamps
+                video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
+
+                if not video_path.exists():
+                    self.console.print(f"[yellow]Warning: Video not found for episode {ep_idx}[/yellow]")
+                    continue
+
+                # Get episode timestamps from metadata
+                ep = dataset.meta.episodes[ep_idx]
+                start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
+                end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
+                duration = end_ts - start_ts
+
+                # Extract episode segment to temporary file
+                extracted_path = self.video_extractor.extract_episode_video(
+                    video_path, start_ts, end_ts, target_fps=1
+                )
+
+                extracted_paths.append(extracted_path)
+                durations.append(duration)
+                valid_episode_indices.append(ep_idx)
+
+            except Exception as e:
+                self.console.print(f"[yellow]Warning: Failed to extract video for episode {ep_idx}: {e}[/yellow]")
+                continue
+
+        if not extracted_paths:
+            return {}
+
+        try:
+            # Run VLM skill segmentation in batch
+            all_skills = self.vlm.segment_skills_batch(extracted_paths, durations, coarse_goal)
+
+            # Map results back to episode indices
+            results = {}
+            for ep_idx, skills in zip(valid_episode_indices, all_skills):
+                results[ep_idx] = skills
+
+            return results
+
+        finally:
+            # Clean up all temporary files
+            for path in extracted_paths:
+                if path.exists():
+                    path.unlink()
+
     def _annotate_episode(
         self,
         dataset: LeRobotDataset,
@@ -1007,15 +1092,15 @@ def load_skill_annotations(dataset_root: Path) -> dict | None:
 def main():
     """Main entry point for the skill annotation script."""
     parser = argparse.ArgumentParser(
-        description="Automatic skill annotation for LeRobot datasets using VLMs",
+        description="Automatic skill annotation for LeRobot datasets using VLMs (with batched processing)",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog=textwrap.dedent("""\
             Examples:
             # Annotate a HuggingFace Hub dataset
            python annotate.py --repo-id user/dataset --video-key observation.images.base
 
-            # Annotate a local dataset
-            python annotate.py --data-dir /path/to/dataset --video-key observation.images.base
+            # Annotate a local dataset with custom batch size
+            python annotate.py --data-dir /path/to/dataset --video-key observation.images.base --batch-size 16
 
             # Use a specific model
            python annotate.py --repo-id user/dataset --video-key observation.images.base \\
@@ -1059,6 +1144,12 @@
         choices=["bfloat16", "float16", "float32"],
         help="Model dtype (default: bfloat16)",
     )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=8,
+        help="Number of episodes to process in each batch (default: 8)",
+    )
 
     # Episode selection
     parser.add_argument(
@@ -1116,7 +1207,8 @@
     vlm = get_vlm(args.model, args.device, torch_dtype)
 
     # Create annotator and run annotation
-    annotator = SkillAnnotator(vlm=vlm, console=console)
+    annotator = SkillAnnotator(vlm=vlm, console=console, batch_size=args.batch_size)
+    console.print(f"[cyan]Processing with batch size: {args.batch_size}[/cyan]")
     annotations = annotator.annotate_dataset(
         dataset=dataset,
         video_key=args.video_key,