From 1ffdc6f49e6e7e63abe962365d7673351c118548 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 18 Nov 2025 15:28:40 +0100 Subject: [PATCH] subtasks --- .../dataset_annotation/subtask_annotation.py | 299 +++++++++++++++--- scripts/visualize_sarm_predictions.py | 100 +++++- .../policies/sarm/configuration_sarm.py | 5 +- src/lerobot/policies/sarm/modeling_sarm.py | 51 ++- src/lerobot/policies/sarm/processor_sarm.py | 225 +++++++++++++ 5 files changed, 624 insertions(+), 56 deletions(-) diff --git a/examples/dataset_annotation/subtask_annotation.py b/examples/dataset_annotation/subtask_annotation.py index 5689b6a09..0df1c942f 100644 --- a/examples/dataset_annotation/subtask_annotation.py +++ b/examples/dataset_annotation/subtask_annotation.py @@ -42,13 +42,30 @@ Usage: # Install dependencies pip install transformers torch qwen-vl-utils accelerate -# Annotate and push to hub: +# Sequential processing (single GPU): python subtask_annotation.py \\ --repo-id pepijn223/mydataset \\ --subtasks "reach,grasp,lift,place" \\ --video-key observation.images.base \\ --push-to-hub +# Parallel processing (4 GPUs): +python subtask_annotation.py \\ + --repo-id pepijn223/mydataset \\ + --subtasks "reach,grasp,lift,place" \\ + --video-key observation.images.base \\ + --num-workers 4 \\ + --push-to-hub + +# Parallel with specific GPU IDs: +python subtask_annotation.py \\ + --repo-id pepijn223/mydataset \\ + --subtasks "reach,grasp,lift,place" \\ + --video-key observation.images.base \\ + --num-workers 2 \\ + --gpu-ids 0 2 \\ + --push-to-hub + """ import argparse @@ -56,6 +73,8 @@ import json import time from pathlib import Path from typing import Optional +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed +import multiprocessing as mp import pandas as pd import torch @@ -63,7 +82,7 @@ from pydantic import BaseModel, Field from qwen_vl_utils import process_vision_info from rich.console import Console from rich.panel import Panel -from rich.progress import Progress, SpinnerColumn, TextColumn +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn from rich.tree import Tree from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor @@ -177,7 +196,7 @@ class VideoAnnotator: file_path: Path, start_timestamp: float, end_timestamp: float, - target_fps: int = 2 + target_fps: int = 1 ) -> Path: """ Extract a specific episode segment from concatenated video. @@ -187,7 +206,7 @@ class VideoAnnotator: file_path: Path to the concatenated video file start_timestamp: Starting timestamp in seconds (within this video file) end_timestamp: Ending timestamp in seconds (within this video file) - target_fps: Target FPS (default: 2 for faster processing) + target_fps: Target FPS (default: 1 for faster processing) Returns: Path to extracted video file @@ -683,6 +702,77 @@ def process_single_episode( return ep_idx, None, str(e) +def worker_process_episodes( + worker_id: int, + gpu_id: int, + episode_indices: list[int], + repo_id: str, + video_key: str, + subtask_list: list[str], + model_name: str, + torch_dtype: torch.dtype, +) -> dict[int, SubtaskAnnotation]: + """ + Worker function for parallel processing across GPUs. 
+ + Args: + worker_id: Worker ID for logging + gpu_id: GPU device ID to use + episode_indices: List of episode indices to process + repo_id: Dataset repo ID + video_key: Video key to use + subtask_list: List of subtask names + model_name: Model name to load + torch_dtype: Model dtype + + Returns: + Dictionary of episode_idx -> SubtaskAnnotation + """ + # Set GPU device + device = f"cuda:{gpu_id}" + + # Initialize console for this worker + console = Console() + console.print(f"[cyan]Worker {worker_id} starting on GPU {gpu_id} with {len(episode_indices)} episodes[/cyan]") + + # Load dataset (this is lightweight, just metadata) + dataset = LeRobotDataset(repo_id, download_videos=False) + fps = dataset.fps + + # Initialize annotator for this worker + annotator = VideoAnnotator( + subtask_list=subtask_list, + model_name=model_name, + device=device, + torch_dtype=torch_dtype + ) + + # Process assigned episodes + annotations = {} + + for i, ep_idx in enumerate(episode_indices): + console.print(f"[cyan]Worker {worker_id} | Episode {ep_idx} ({i+1}/{len(episode_indices)})[/cyan]") + + result_ep_idx, annotation, error = process_single_episode( + ep_idx, + dataset.root, + dataset.meta, + video_key, + fps, + annotator, + console + ) + + if error: + console.print(f"[red]Worker {worker_id} | ✗ Failed episode {result_ep_idx}: {error}[/red]") + elif annotation: + annotations[result_ep_idx] = annotation + console.print(f"[green]Worker {worker_id} | ✓ Completed episode {result_ep_idx}[/green]") + + console.print(f"[bold green]Worker {worker_id} completed {len(annotations)}/{len(episode_indices)} episodes[/bold green]") + return annotations + + def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]: """ Compute average temporal proportion for each subtask across all episodes. @@ -742,16 +832,41 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: + # Sequential processing (single GPU): + python subtask_annotation.py \\ + --repo-id pepijn223/mydataset \\ + --subtasks "reach,grasp,lift,place" \\ + --video-key observation.images.top \\ + --push-to-hub + + # Parallel processing with 4 GPUs (4x speedup): + python subtask_annotation.py \\ + --repo-id pepijn223/mydataset \\ + --subtasks "reach,grasp,lift,place" \\ + --video-key observation.images.top \\ + --num-workers 4 \\ + --push-to-hub + + # Parallel with specific GPU IDs (e.g., GPUs 0, 2, 3): + python subtask_annotation.py \\ + --repo-id pepijn223/mydataset \\ + --subtasks "reach,grasp,lift,place" \\ + --video-key observation.images.top \\ + --num-workers 3 \\ + --gpu-ids 0 2 3 \\ + --push-to-hub + # List available cameras: - python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --max-episodes 0 - - # Annotate with specific camera: - python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --push-to-hub - - # Use custom model: - python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --model Qwen/Qwen3-VL-30B-A3B-Instruct --push-to-hub + python subtask_annotation.py \\ + --repo-id pepijn223/mydataset \\ + --subtasks "reach,grasp" \\ + --max-episodes 0 -Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memory. 
+Performance Tips: + - Each worker loads one model instance on its assigned GPU + - The 30B model requires ~60GB VRAM per GPU + - Use --num-workers N for N GPUs to get N× speedup + - Episodes are distributed round-robin across workers """ ) parser.add_argument( @@ -821,6 +936,27 @@ Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memor choices=["bfloat16", "float16", "float32"], help="Model dtype (default: bfloat16)", ) + parser.add_argument( + "--num-workers", + type=int, + default=1, + help="Number of parallel workers for multi-GPU processing (default: 1 for sequential). " + "Set to number of GPUs available for parallel processing.", + ) + parser.add_argument( + "--gpu-ids", + type=int, + nargs="+", + default=None, + help="Specific GPU IDs to use (e.g., --gpu-ids 0 1 2). " + "If not specified, uses GPUs 0 to num-workers-1.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size for processing multiple episodes per inference (experimental, default: 1)", + ) args = parser.parse_args() @@ -898,39 +1034,120 @@ Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memor console.print("[green]All episodes already annotated![/green]") return - # Initialize annotator with subtask list - annotator = VideoAnnotator( - subtask_list=subtask_list, - model_name=args.model, - device=args.device, - torch_dtype=torch_dtype - ) + # Determine GPU IDs to use + if args.gpu_ids: + gpu_ids = args.gpu_ids + if len(gpu_ids) < args.num_workers: + console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {len(gpu_ids)} GPU IDs provided[/yellow]") + args.num_workers = len(gpu_ids) + else: + # Check available GPUs + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + if args.num_workers > num_gpus: + console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {num_gpus} GPUs available[/yellow]") + args.num_workers = min(args.num_workers, num_gpus) + gpu_ids = list(range(args.num_workers)) + else: + console.print("[yellow]Warning: CUDA not available, using CPU (num_workers will be ignored)[/yellow]") + args.num_workers = 1 + gpu_ids = [0] # Dummy value for CPU - # Annotate episodes (sequential processing) + # Annotate episodes - choose sequential or parallel mode annotations = existing_annotations.copy() - for i, ep_idx in enumerate(episode_indices): - console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]") - console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]") - console.print(f"[bold cyan]{'=' * 60}[/bold cyan]") - - result_ep_idx, annotation, error = process_single_episode( - ep_idx, - dataset.root, - dataset.meta, - video_key, - fps, - annotator, - console - ) + if args.num_workers > 1: + # ===== PARALLEL PROCESSING MODE ===== + console.print(f"\n[bold cyan]Using {args.num_workers} parallel workers on GPUs: {gpu_ids}[/bold cyan]") - if error: - console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]") - continue - elif annotation: - annotations[result_ep_idx] = annotation - display_annotation(annotation, console, result_ep_idx, fps) - save_annotations_to_dataset(dataset.root, annotations, fps) + # Split episodes across workers + episodes_per_worker = [[] for _ in range(args.num_workers)] + for i, ep_idx in enumerate(episode_indices): + worker_idx = i % args.num_workers + episodes_per_worker[worker_idx].append(ep_idx) + + # Show distribution + for worker_id, episodes in enumerate(episodes_per_worker): + 
console.print(f"[cyan]Worker {worker_id} (GPU {gpu_ids[worker_id]}): {len(episodes)} episodes[/cyan]") + + # Start parallel processing using ProcessPoolExecutor + console.print(f"\n[bold cyan]Starting parallel annotation...[/bold cyan]") + + with ProcessPoolExecutor(max_workers=args.num_workers) as executor: + # Submit all worker jobs + futures = [] + for worker_id in range(args.num_workers): + if not episodes_per_worker[worker_id]: + continue # Skip workers with no episodes + + future = executor.submit( + worker_process_episodes, + worker_id, + gpu_ids[worker_id], + episodes_per_worker[worker_id], + args.repo_id, + video_key, + subtask_list, + args.model, + torch_dtype, + ) + futures.append(future) + + # Collect results as they complete + for future in as_completed(futures): + try: + worker_annotations = future.result() + annotations.update(worker_annotations) + + # Save after each worker completes + save_annotations_to_dataset(dataset.root, annotations, fps) + console.print(f"[green]✓ Worker completed, saved {len(worker_annotations)} annotations[/green]") + + except Exception as e: + console.print(f"[red]✗ Worker failed: {e}[/red]") + + console.print(f"\n[bold green]Parallel processing complete! Annotated {len(annotations)} episodes[/bold green]") + + # Display all annotations + for ep_idx in sorted(annotations.keys()): + if ep_idx not in existing_annotations: # Only show newly annotated + display_annotation(annotations[ep_idx], console, ep_idx, fps) + + else: + # ===== SEQUENTIAL PROCESSING MODE ===== + console.print(f"\n[bold cyan]Using sequential processing (single GPU/CPU)[/bold cyan]") + + # Initialize annotator with subtask list + annotator = VideoAnnotator( + subtask_list=subtask_list, + model_name=args.model, + device=args.device, + torch_dtype=torch_dtype + ) + + # Process episodes sequentially + for i, ep_idx in enumerate(episode_indices): + console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]") + console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]") + console.print(f"[bold cyan]{'=' * 60}[/bold cyan]") + + result_ep_idx, annotation, error = process_single_episode( + ep_idx, + dataset.root, + dataset.meta, + video_key, + fps, + annotator, + console + ) + + if error: + console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]") + continue + elif annotation: + annotations[result_ep_idx] = annotation + display_annotation(annotation, console, result_ep_idx, fps) + save_annotations_to_dataset(dataset.root, annotations, fps) # Compute temporal proportions (key SARM insight) console.print(f"\n[bold cyan]Computing Temporal Proportions[/bold cyan]") diff --git a/scripts/visualize_sarm_predictions.py b/scripts/visualize_sarm_predictions.py index 1944803a2..20ff651f6 100644 --- a/scripts/visualize_sarm_predictions.py +++ b/scripts/visualize_sarm_predictions.py @@ -338,7 +338,9 @@ def visualize_predictions( output_path: Path, show_frames: bool = False, num_sample_frames: int = 8, - figsize: tuple = (14, 8) + figsize: tuple = (14, 8), + subtask_names: list[str] | None = None, + temporal_proportions: dict[str, float] | None = None ): """ Create visualization of SARM predictions. 
@@ -352,10 +354,18 @@ def visualize_predictions( show_frames: Whether to include sample frames num_sample_frames: Number of frames to show figsize: Figure size (width, height) + subtask_names: Optional list of subtask names for labeling + temporal_proportions: Optional dict of temporal proportions for each subtask """ num_stages = stage_predictions.shape[1] stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages)) + # Use subtask names if available, otherwise use generic labels + if subtask_names is not None and len(subtask_names) == num_stages: + stage_labels = subtask_names + else: + stage_labels = [f'Stage {i+1}' for i in range(num_stages)] + if show_frames: # Create figure with progress plot, stage plot, and sample frames fig = plt.figure(figsize=(figsize[0], figsize[1] + 4)) @@ -398,12 +408,35 @@ def visualize_predictions( # Plot 2: Stage predictions (stacked area plot) ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)], - colors=stage_colors, alpha=0.8, labels=[f'Stage {i+1}' for i in range(num_stages)]) + colors=stage_colors, alpha=0.8, labels=stage_labels) ax_stages.set_xlabel('Frame Index', fontsize=12) ax_stages.set_ylabel('Stage Probability', fontsize=12) ax_stages.set_ylim(0, 1) ax_stages.grid(True, alpha=0.3) - ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8) + + # Adjust legend based on number of stages and label lengths + if num_stages <= 5: + ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8) + else: + ax_stages.legend(loc='upper left', ncol=3, fontsize=7) + + # Add vertical lines and labels for expected stage transitions (if temporal proportions available) + if temporal_proportions is not None and subtask_names is not None: + cumulative_progress = 0.0 + for i, name in enumerate(stage_labels): + if name in temporal_proportions: + # Find approximate frame where this stage should end + stage_end_progress = cumulative_progress + temporal_proportions[name] + + # Find frame index closest to this progress + progress_diffs = np.abs(progress_predictions - stage_end_progress) + stage_end_frame = np.argmin(progress_diffs) + + # Draw vertical line + ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1) + ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1) + + cumulative_progress = stage_end_progress # Plot 3: Sample frames (if requested) if show_frames: @@ -431,7 +464,13 @@ def visualize_predictions( # Add frame number, progress, and stage progress_val = progress_predictions[frame_idx] stage_idx = np.argmax(stage_predictions[frame_idx]) - label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\nStage: {stage_idx+1}' + stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}' + + # Truncate long stage names for display + if len(stage_name) > 15: + stage_name = stage_name[:12] + '...' 
+ + label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}' # Draw label on image ax_frames.text(x_start + frame_width / 2, -10, label, @@ -495,6 +534,29 @@ def main(): # Run inference progress_predictions, stage_predictions = run_inference(model, frames, states, task_description) + # Extract subtask names and temporal proportions from model config if available + subtask_names = None + temporal_proportions = None + + if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None: + subtask_names = model.config.subtask_names + logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}") + + # Try to load temporal proportions from model's dataset meta + if hasattr(model, 'dataset_stats') and model.dataset_stats is not None: + if 'temporal_proportions' in model.dataset_stats: + temporal_proportions = model.dataset_stats['temporal_proportions'] + logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}") + + # Fallback: try to load from dataset meta + if temporal_proportions is None and subtask_names is not None: + import json + proportions_path = dataset.root / "meta" / "temporal_proportions.json" + if proportions_path.exists(): + with open(proportions_path, 'r') as f: + temporal_proportions = json.load(f) + logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}") + # Create visualization output_dir = Path(args.output_dir) output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png" @@ -507,7 +569,9 @@ def main(): output_path, show_frames=args.show_frames, num_sample_frames=args.num_sample_frames, - figsize=tuple(args.figsize) + figsize=tuple(args.figsize), + subtask_names=subtask_names, + temporal_proportions=temporal_proportions ) # Save predictions as numpy arrays @@ -527,8 +591,30 @@ def main(): logger.info(f"Final Progress: {progress_predictions[-1]:.3f}") logger.info(f"Max Progress: {progress_predictions.max():.3f}") logger.info(f"Mean Progress: {progress_predictions.mean():.3f}") - logger.info(f"Most Common Stage: {np.argmax(np.sum(stage_predictions, axis=0)) + 1}") - logger.info(f"Visualization: {output_path}") + + # Show most common stage with name if available + most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0)) + if subtask_names is not None and most_common_stage_idx < len(subtask_names): + most_common_stage_name = subtask_names[most_common_stage_idx] + logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})") + else: + logger.info(f"Most Common Stage: {most_common_stage_idx + 1}") + + # Show subtask breakdown + if subtask_names is not None: + logger.info("\nSubtask Breakdown:") + total_frames = len(stage_predictions) + for i, name in enumerate(subtask_names): + # Calculate percentage of frames where this stage was dominant + dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i) + percentage = (dominant_frames / total_frames) * 100 + logger.info(f" {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)") + + if temporal_proportions is not None and name in temporal_proportions: + expected_pct = temporal_proportions[name] * 100 + logger.info(f" Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%") + + logger.info(f"\nVisualization: {output_path}") logger.info("="*60) diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 67a497106..3592f050b 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ 
b/src/lerobot/policies/sarm/configuration_sarm.py @@ -49,7 +49,8 @@ class SARMConfig(PreTrainedConfig): hidden_dim: int = 768 # Transformer hidden dimension num_heads: int = 12 # Number of attention heads num_layers: int = 8 # Number of transformer layers - num_stages: int = 5 # Number of task stages for classification + num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available) + subtask_names: list | None = None # List of subtask names (auto-populated from annotations) # Temporal parameters max_length: int = 9 # Maximum video sequence length (should match num_frames) @@ -61,6 +62,7 @@ class SARMConfig(PreTrainedConfig): clip_batch_size: int = 64 # Batch size for CLIP encoding gradient_checkpointing: bool = False # Enable gradient checkpointing dropout: float = 0.1 # Dropout rate + stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations # RA-BC (Reward-Aligned Behavior Cloning) parameters enable_rabc: bool = False # Enable RA-BC weighted loss @@ -79,6 +81,7 @@ class SARMConfig(PreTrainedConfig): task_description: str = "perform the task" # Default task description encode_on_the_fly: bool = True # Encode images/text during training use_dataset_task: bool = True # Use task descriptions from dataset + use_subtask_annotations: bool = True # Use subtask annotations for stage-aware training if available # Features (required by PreTrainedPolicy) input_features: dict = field(default_factory=lambda: { diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 1dd53e07d..084db750c 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -573,6 +573,15 @@ class SARMRewardModel(PreTrainedPolicy): if state_features is not None: state_features = state_features.to(self.device) + # Extract stage labels and progress targets if available (from subtask annotations) + stage_labels = observation.get('stage_labels', None) + if stage_labels is not None: + stage_labels = stage_labels.to(self.device) + + progress_targets_from_annotations = observation.get('progress_targets', None) + if progress_targets_from_annotations is not None: + progress_targets_from_annotations = progress_targets_from_annotations.to(self.device) + batch_size = video_features.shape[0] max_length = self.config.num_frames @@ -713,16 +722,44 @@ class SARMRewardModel(PreTrainedPolicy): processed_videos, text_features, processed_states ) - # Compute progress loss using augmented targets + # Use annotation-based progress targets if available, otherwise use computed ones + if progress_targets_from_annotations is not None and len(processed_videos) == 1: + # Use refined progress from subtask annotations (single sample case) + # Ensure shapes match + if progress_targets_from_annotations.shape != progress_preds.shape: + if progress_targets_from_annotations.dim() == 2: + progress_targets_from_annotations = progress_targets_from_annotations.unsqueeze(0) + progress_targets = progress_targets_from_annotations + + # Compute progress loss using targets progress_loss = F.mse_loss(progress_preds, progress_targets) - # For now, just use progress loss since we don't have stage annotations - # In future: can add stage loss when we have annotated stage labels - total_loss = progress_loss + # Compute stage loss if labels are available + stage_loss = None + if stage_labels is not None and len(processed_videos) == 1: + # Ensure stage_labels matches the sequence length + if 
stage_labels.dim() == 1 and stage_logits.dim() == 3: + # stage_labels: (seq_len,) -> need to expand to (batch, seq_len) + stage_labels = stage_labels.unsqueeze(0).expand(stage_logits.shape[0], -1) + elif stage_labels.shape[0] != stage_logits.shape[0]: + # Single label for batch - expand + stage_labels = stage_labels.expand(stage_logits.shape[0], stage_logits.shape[1]) + + # Compute cross-entropy loss for stage classification + stage_loss = compute_stage_loss(stage_logits, stage_labels) - output_dict = { - 'progress_loss': progress_loss.item(), - } + # Combine losses + if stage_loss is not None: + total_loss = progress_loss + self.config.stage_loss_weight * stage_loss + output_dict = { + 'progress_loss': progress_loss.item(), + 'stage_loss': stage_loss.item(), + } + else: + total_loss = progress_loss + output_dict = { + 'progress_loss': progress_loss.item(), + } # Compute misaligned loss (following SARM paper and ReWiND) # "To improve video-language alignment, task descriptions are occasionally perturbed" diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 9c4fedc53..c3a6e915e 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -65,6 +65,12 @@ class SARMEncodingProcessorStep(ProcessorStep): self.dataset_meta = dataset_meta self.dataset_stats = dataset_stats + # Compute temporal proportions from subtask annotations if available + self.temporal_proportions = None + self.subtask_names = None + if dataset_meta is not None and config.use_subtask_annotations: + self._compute_temporal_proportions() + # Initialize encoders self._init_encoders() @@ -95,6 +101,216 @@ class SARMEncodingProcessorStep(ProcessorStep): self.device = device + def _compute_temporal_proportions(self): + """Compute temporal proportions for each subtask from dataset annotations.""" + if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'): + return + + # Check if subtask annotations exist + episodes = self.dataset_meta.episodes + if episodes is None or len(episodes) == 0: + return + + # Check for subtask_names column + if 'subtask_names' not in episodes.column_names: + logging.info("No subtask annotations found in dataset") + return + + # Convert to pandas for easier processing + import pandas as pd + episodes_df = episodes.to_pandas() + + # Collect all subtask names and compute average durations + subtask_durations = {} + subtask_counts = {} + all_subtask_names = set() + + for ep_idx in episodes_df.index: + subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] + + # Skip episodes without annotations + if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): + continue + + start_times = episodes_df.loc[ep_idx, 'subtask_start_times'] + end_times = episodes_df.loc[ep_idx, 'subtask_end_times'] + + # Track unique subtask names + all_subtask_names.update(subtask_names) + + # Compute durations + for i, name in enumerate(subtask_names): + duration = end_times[i] - start_times[i] + if name not in subtask_durations: + subtask_durations[name] = [] + subtask_durations[name].append(duration) + + if not all_subtask_names: + logging.info("No valid subtask annotations found") + return + + # Sort subtask names for consistent ordering + self.subtask_names = sorted(list(all_subtask_names)) + self.config.num_stages = len(self.subtask_names) + self.config.subtask_names = self.subtask_names # Store in config for reference + + # Compute average duration for each subtask + 
avg_durations = {} + for name in self.subtask_names: + if name in subtask_durations: + avg_durations[name] = np.mean(subtask_durations[name]) + else: + avg_durations[name] = 0.0 + + # Normalize to get proportions + total_duration = sum(avg_durations.values()) + if total_duration > 0: + self.temporal_proportions = { + name: avg_durations[name] / total_duration + for name in self.subtask_names + } + else: + # Equal proportions if no duration info + self.temporal_proportions = { + name: 1.0 / len(self.subtask_names) + for name in self.subtask_names + } + + logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}") + + def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features): + """Generate stage labels and refined progress targets from subtask annotations. + + Args: + frame_index: Current frame index or indices + episode_index: Episode index + video_features: Video features tensor to determine sequence length + + Returns: + Tuple of (stage_labels, progress_targets) or (None, None) if no annotations + """ + if self.temporal_proportions is None or episode_index is None: + return None, None + + # Convert to pandas to access annotations + import pandas as pd + episodes_df = self.dataset_meta.episodes.to_pandas() + + # Handle batch processing + is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1 + + if is_batch: + # Process multiple samples - for now, return None + # (batch processing of annotations is complex and not critical) + return None, None + + # Single sample processing + if isinstance(episode_index, torch.Tensor): + ep_idx = int(episode_index.item()) + else: + ep_idx = int(episode_index) + + if isinstance(frame_index, torch.Tensor): + frame_idx = int(frame_index.item()) + else: + frame_idx = int(frame_index) + + # Get subtask annotations for this episode + if ep_idx >= len(episodes_df): + return None, None + + subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] + if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): + return None, None + + subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames'] + subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames'] + + # Get episode boundaries + ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"] + ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"] + + # Determine sequence length + if video_features is not None and video_features.dim() > 0: + seq_len = video_features.shape[0] if video_features.dim() == 2 else video_features.shape[1] + else: + seq_len = 1 + + # Generate labels for each frame in the sequence + stage_labels = [] + progress_targets = [] + + # Get frame gap for temporal sampling + frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1 + + for i in range(seq_len): + # Calculate actual frame index for this position in sequence + if frame_gap > 1: + offset = -(seq_len - 1 - i) * frame_gap + current_frame = max(0, frame_idx + offset - ep_start) + else: + current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start) + + # Find which subtask this frame belongs to + stage_idx = -1 + within_subtask_progress = 0.0 + cumulative_progress = 0.0 + + for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)): + if current_frame >= start_frame and current_frame <= end_frame: + # Found the subtask + stage_idx = self.subtask_names.index(name) if name in self.subtask_names 
else 0 + + # Calculate within-subtask progress + subtask_duration = end_frame - start_frame + if subtask_duration > 0: + within_subtask_progress = (current_frame - start_frame) / subtask_duration + else: + within_subtask_progress = 1.0 + + # Calculate cumulative progress + for k in range(j): + prev_name = subtask_names[k] + if prev_name in self.temporal_proportions: + cumulative_progress += self.temporal_proportions[prev_name] + + # Add current subtask's partial progress + if name in self.temporal_proportions: + cumulative_progress += self.temporal_proportions[name] * within_subtask_progress + + break + + # If no matching subtask found, estimate based on position + if stage_idx == -1: + # Estimate stage based on frame position + if current_frame < subtask_start_frames[0]: + stage_idx = 0 + cumulative_progress = 0.0 + elif current_frame > subtask_end_frames[-1]: + stage_idx = len(self.subtask_names) - 1 + cumulative_progress = 1.0 + else: + # Between subtasks - use previous subtask's end state + for j in range(len(subtask_names) - 1): + if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]: + name = subtask_names[j] + stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j + # Sum up all previous subtasks + for k in range(j + 1): + prev_name = subtask_names[k] + if prev_name in self.temporal_proportions: + cumulative_progress += self.temporal_proportions[prev_name] + break + + stage_labels.append(stage_idx) + progress_targets.append(cumulative_progress) + + # Convert to tensors + stage_labels = torch.tensor(stage_labels, dtype=torch.long) + progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) # Add channel dim + + return stage_labels, progress_targets + def __call__(self, transition: EnvTransition) -> EnvTransition: """Encode images, text, and normalize states in the transition.""" from lerobot.processor.core import TransitionKey @@ -351,6 +567,15 @@ class SARMEncodingProcessorStep(ProcessorStep): observation['episode_length'] = episode_length + # Generate stage labels and refined progress from subtask annotations + if self.temporal_proportions is not None and self.dataset_meta is not None: + stage_labels, progress_targets = self._generate_stage_and_progress_labels( + frame_index, episode_index, observation.get('video_features') + ) + if stage_labels is not None: + observation['stage_labels'] = stage_labels + observation['progress_targets'] = progress_targets + new_transition[TransitionKey.OBSERVATION] = observation return new_transition
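A worked illustration (not part of the patch) of the progress targets that `_generate_stage_and_progress_labels` derives from subtask annotations: the target for a frame is the sum of the temporal proportions of all completed subtasks plus the current subtask's proportion scaled by the within-subtask progress. The proportion values below are assumptions chosen only to make the arithmetic concrete.

    # With proportions reach=0.2, grasp=0.3, lift=0.2, place=0.3, a frame halfway
    # through "grasp" gets a progress target of 0.2 + 0.3 * 0.5 = 0.35.
    temporal_proportions = {"reach": 0.2, "grasp": 0.3, "lift": 0.2, "place": 0.3}
    subtask_order = ["reach", "grasp", "lift", "place"]

    def progress_target(subtask: str, within: float) -> float:
        """Cumulative progress: earlier proportions plus the partial current one."""
        done = sum(temporal_proportions[s] for s in subtask_order[: subtask_order.index(subtask)])
        return done + temporal_proportions[subtask] * within

    assert abs(progress_target("grasp", 0.5) - 0.35) < 1e-9
    assert abs(progress_target("place", 1.0) - 1.0) < 1e-9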