diff --git a/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py b/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py index 74f1a7d54..bf5a98e52 100644 --- a/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py +++ b/src/lerobot/policies/pi05_full/annotate/subtask_annotate.py @@ -333,12 +333,14 @@ class Qwen2VL(BaseVLM): # Parse each response all_skills = [] - for response in responses: + for idx, response in enumerate(responses): try: skills = self._parse_skills_response(response.strip()) + if not skills: + self.console.print(f"[yellow]Warning: No skills parsed from response for video {idx}[/yellow]") all_skills.append(skills) except Exception as e: - self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]") + self.console.print(f"[yellow]Warning: Failed to parse response for video {idx}: {e}[/yellow]") all_skills.append([]) return all_skills @@ -487,12 +489,14 @@ class Qwen3VL(BaseVLM): # Parse each response all_skills = [] - for response in responses: + for idx, response in enumerate(responses): try: skills = self._parse_skills_response(response.strip()) + if not skills: + self.console.print(f"[yellow]Warning: No skills parsed from response for video {idx}[/yellow]") all_skills.append(skills) except Exception as e: - self.console.print(f"[yellow]Warning: Failed to parse response: {e}[/yellow]") + self.console.print(f"[yellow]Warning: Failed to parse response for video {idx}: {e}[/yellow]") all_skills.append([]) return all_skills @@ -690,10 +694,26 @@ class SkillAnnotator: """ episode_indices = episodes or list(range(dataset.meta.total_episodes)) annotations: dict[int, EpisodeSkills] = {} + failed_episodes: dict[int, str] = {} # Track failed episodes with error messages # Get coarse task description if available coarse_goal = self._get_coarse_goal(dataset) + # Filter out episodes that already have annotations if skip_existing is True + if skip_existing: + existing_annotations = load_skill_annotations(dataset.root) + if existing_annotations and "episodes" in existing_annotations: + existing_episode_indices = {int(idx) for idx in existing_annotations["episodes"].keys()} + original_count = len(episode_indices) + episode_indices = [ep for ep in episode_indices if ep not in existing_episode_indices] + skipped_count = original_count - len(episode_indices) + if skipped_count > 0: + self.console.print(f"[cyan]Skipping {skipped_count} episodes with existing annotations[/cyan]") + + if not episode_indices: + self.console.print("[yellow]No episodes to annotate (all already annotated)[/yellow]") + return annotations + print(f"Annotating {len(episode_indices)} episodes in batches of {self.batch_size}...") # Process episodes in batches @@ -708,21 +728,9 @@ class SkillAnnotator: dataset, batch_episodes, video_key, coarse_goal ) - for ep_idx, skills in batch_annotations.items(): - annotations[ep_idx] = EpisodeSkills( - episode_index=ep_idx, - description=coarse_goal, - skills=skills, - ) - self.console.print( - f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]" - ) - except Exception as e: - self.console.print(f"[red]✗ Batch failed: {e}. Falling back to single-episode processing...[/red]") - # Fallback: process episodes one by one for ep_idx in batch_episodes: - try: - skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal) + if ep_idx in batch_annotations and batch_annotations[ep_idx]: + skills = batch_annotations[ep_idx] annotations[ep_idx] = EpisodeSkills( episode_index=ep_idx, description=coarse_goal, @@ -731,8 +739,63 @@ class SkillAnnotator: self.console.print( f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]" ) - except Exception as e: - self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]") + else: + failed_episodes[ep_idx] = "Empty or missing skills from batch processing" + self.console.print(f"[yellow]⚠ Episode {ep_idx}: No skills extracted, will retry[/yellow]") + except Exception as e: + self.console.print(f"[red]✗ Batch failed: {e}. Falling back to single-episode processing...[/red]") + # Fallback: process episodes one by one + for ep_idx in batch_episodes: + try: + skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal) + if skills: + annotations[ep_idx] = EpisodeSkills( + episode_index=ep_idx, + description=coarse_goal, + skills=skills, + ) + self.console.print( + f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]" + ) + else: + failed_episodes[ep_idx] = "Empty skills list from single-episode processing" + self.console.print(f"[yellow]⚠ Episode {ep_idx}: No skills extracted, will retry[/yellow]") + except Exception as ep_error: + failed_episodes[ep_idx] = str(ep_error) + self.console.print(f"[yellow]⚠ Episode {ep_idx} failed: {ep_error}, will retry[/yellow]") + + # Retry failed episodes one more time + if failed_episodes: + self.console.print(f"\n[cyan]Retrying {len(failed_episodes)} failed episodes...[/cyan]") + retry_count = 0 + for ep_idx, error_msg in list(failed_episodes.items()): + self.console.print(f"[cyan]Retry attempt for episode {ep_idx} (previous error: {error_msg})[/cyan]") + try: + skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal) + if skills: + annotations[ep_idx] = EpisodeSkills( + episode_index=ep_idx, + description=coarse_goal, + skills=skills, + ) + self.console.print( + f"[green]✓ Episode {ep_idx} (retry): {len(skills)} skills identified[/green]" + ) + del failed_episodes[ep_idx] + retry_count += 1 + else: + self.console.print(f"[red]✗ Episode {ep_idx} (retry): Still no skills extracted[/red]") + except Exception as retry_error: + failed_episodes[ep_idx] = str(retry_error) + self.console.print(f"[red]✗ Episode {ep_idx} (retry) failed: {retry_error}[/red]") + + if retry_count > 0: + self.console.print(f"[green]Successfully recovered {retry_count} episodes on retry[/green]") + + if failed_episodes: + self.console.print(f"\n[red]⚠ Warning: {len(failed_episodes)} episodes still failed after retry:[/red]") + for ep_idx, error_msg in failed_episodes.items(): + self.console.print(f" Episode {ep_idx}: {error_msg}") return annotations