This commit is contained in:
Pepijn
2025-11-18 15:28:40 +01:00
parent f688eb160b
commit 1ffdc6f49e
5 changed files with 624 additions and 56 deletions
+258 -41
View File
@@ -42,13 +42,30 @@ Usage:
# Install dependencies # Install dependencies
pip install transformers torch qwen-vl-utils accelerate pip install transformers torch qwen-vl-utils accelerate
# Annotate and push to hub: # Sequential processing (single GPU):
python subtask_annotation.py \\ python subtask_annotation.py \\
--repo-id pepijn223/mydataset \\ --repo-id pepijn223/mydataset \\
--subtasks "reach,grasp,lift,place" \\ --subtasks "reach,grasp,lift,place" \\
--video-key observation.images.base \\ --video-key observation.images.base \\
--push-to-hub --push-to-hub
# Parallel processing (4 GPUs):
python subtask_annotation.py \\
--repo-id pepijn223/mydataset \\
--subtasks "reach,grasp,lift,place" \\
--video-key observation.images.base \\
--num-workers 4 \\
--push-to-hub
# Parallel with specific GPU IDs:
python subtask_annotation.py \\
--repo-id pepijn223/mydataset \\
--subtasks "reach,grasp,lift,place" \\
--video-key observation.images.base \\
--num-workers 2 \\
--gpu-ids 0 2 \\
--push-to-hub
""" """
import argparse import argparse
@@ -56,6 +73,8 @@ import json
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import multiprocessing as mp
import pandas as pd import pandas as pd
import torch import torch
@@ -63,7 +82,7 @@ from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich.tree import Tree from rich.tree import Tree
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
@@ -177,7 +196,7 @@ class VideoAnnotator:
file_path: Path, file_path: Path,
start_timestamp: float, start_timestamp: float,
end_timestamp: float, end_timestamp: float,
target_fps: int = 2 target_fps: int = 1
) -> Path: ) -> Path:
""" """
Extract a specific episode segment from concatenated video. Extract a specific episode segment from concatenated video.
@@ -187,7 +206,7 @@ class VideoAnnotator:
file_path: Path to the concatenated video file file_path: Path to the concatenated video file
start_timestamp: Starting timestamp in seconds (within this video file) start_timestamp: Starting timestamp in seconds (within this video file)
end_timestamp: Ending timestamp in seconds (within this video file) end_timestamp: Ending timestamp in seconds (within this video file)
target_fps: Target FPS (default: 2 for faster processing) target_fps: Target FPS (default: 1 for faster processing)
Returns: Returns:
Path to extracted video file Path to extracted video file
@@ -683,6 +702,77 @@ def process_single_episode(
return ep_idx, None, str(e) return ep_idx, None, str(e)
def worker_process_episodes(
worker_id: int,
gpu_id: int,
episode_indices: list[int],
repo_id: str,
video_key: str,
subtask_list: list[str],
model_name: str,
torch_dtype: torch.dtype,
) -> dict[int, SubtaskAnnotation]:
"""
Worker function for parallel processing across GPUs.
Args:
worker_id: Worker ID for logging
gpu_id: GPU device ID to use
episode_indices: List of episode indices to process
repo_id: Dataset repo ID
video_key: Video key to use
subtask_list: List of subtask names
model_name: Model name to load
torch_dtype: Model dtype
Returns:
Dictionary of episode_idx -> SubtaskAnnotation
"""
# Set GPU device
device = f"cuda:{gpu_id}"
# Initialize console for this worker
console = Console()
console.print(f"[cyan]Worker {worker_id} starting on GPU {gpu_id} with {len(episode_indices)} episodes[/cyan]")
# Load dataset (this is lightweight, just metadata)
dataset = LeRobotDataset(repo_id, download_videos=False)
fps = dataset.fps
# Initialize annotator for this worker
annotator = VideoAnnotator(
subtask_list=subtask_list,
model_name=model_name,
device=device,
torch_dtype=torch_dtype
)
# Process assigned episodes
annotations = {}
for i, ep_idx in enumerate(episode_indices):
console.print(f"[cyan]Worker {worker_id} | Episode {ep_idx} ({i+1}/{len(episode_indices)})[/cyan]")
result_ep_idx, annotation, error = process_single_episode(
ep_idx,
dataset.root,
dataset.meta,
video_key,
fps,
annotator,
console
)
if error:
console.print(f"[red]Worker {worker_id} | ✗ Failed episode {result_ep_idx}: {error}[/red]")
elif annotation:
annotations[result_ep_idx] = annotation
console.print(f"[green]Worker {worker_id} | ✓ Completed episode {result_ep_idx}[/green]")
console.print(f"[bold green]Worker {worker_id} completed {len(annotations)}/{len(episode_indices)} episodes[/bold green]")
return annotations
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]: def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
""" """
Compute average temporal proportion for each subtask across all episodes. Compute average temporal proportion for each subtask across all episodes.
@@ -742,16 +832,41 @@ def main():
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
Examples: Examples:
# Sequential processing (single GPU):
python subtask_annotation.py \\
--repo-id pepijn223/mydataset \\
--subtasks "reach,grasp,lift,place" \\
--video-key observation.images.top \\
--push-to-hub
# Parallel processing with 4 GPUs (4x speedup):
python subtask_annotation.py \\
--repo-id pepijn223/mydataset \\
--subtasks "reach,grasp,lift,place" \\
--video-key observation.images.top \\
--num-workers 4 \\
--push-to-hub
# Parallel with specific GPU IDs (e.g., GPUs 0, 2, 3):
python subtask_annotation.py \\
--repo-id pepijn223/mydataset \\
--subtasks "reach,grasp,lift,place" \\
--video-key observation.images.top \\
--num-workers 3 \\
--gpu-ids 0 2 3 \\
--push-to-hub
# List available cameras: # List available cameras:
python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --max-episodes 0 python subtask_annotation.py \\
--repo-id pepijn223/mydataset \\
# Annotate with specific camera: --subtasks "reach,grasp" \\
python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --push-to-hub --max-episodes 0
# Use custom model:
python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --model Qwen/Qwen3-VL-30B-A3B-Instruct --push-to-hub
Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memory. Performance Tips:
- Each worker loads one model instance on its assigned GPU
- The 30B model requires ~60GB VRAM per GPU
- Use --num-workers N for N GPUs to get N× speedup
- Episodes are distributed round-robin across workers
""" """
) )
parser.add_argument( parser.add_argument(
@@ -821,6 +936,27 @@ Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memor
choices=["bfloat16", "float16", "float32"], choices=["bfloat16", "float16", "float32"],
help="Model dtype (default: bfloat16)", help="Model dtype (default: bfloat16)",
) )
parser.add_argument(
"--num-workers",
type=int,
default=1,
help="Number of parallel workers for multi-GPU processing (default: 1 for sequential). "
"Set to number of GPUs available for parallel processing.",
)
parser.add_argument(
"--gpu-ids",
type=int,
nargs="+",
default=None,
help="Specific GPU IDs to use (e.g., --gpu-ids 0 1 2). "
"If not specified, uses GPUs 0 to num-workers-1.",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size for processing multiple episodes per inference (experimental, default: 1)",
)
args = parser.parse_args() args = parser.parse_args()
@@ -898,39 +1034,120 @@ Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memor
console.print("[green]All episodes already annotated![/green]") console.print("[green]All episodes already annotated![/green]")
return return
# Initialize annotator with subtask list # Determine GPU IDs to use
annotator = VideoAnnotator( if args.gpu_ids:
subtask_list=subtask_list, gpu_ids = args.gpu_ids
model_name=args.model, if len(gpu_ids) < args.num_workers:
device=args.device, console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {len(gpu_ids)} GPU IDs provided[/yellow]")
torch_dtype=torch_dtype args.num_workers = len(gpu_ids)
) else:
# Check available GPUs
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
if args.num_workers > num_gpus:
console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {num_gpus} GPUs available[/yellow]")
args.num_workers = min(args.num_workers, num_gpus)
gpu_ids = list(range(args.num_workers))
else:
console.print("[yellow]Warning: CUDA not available, using CPU (num_workers will be ignored)[/yellow]")
args.num_workers = 1
gpu_ids = [0] # Dummy value for CPU
# Annotate episodes (sequential processing) # Annotate episodes - choose sequential or parallel mode
annotations = existing_annotations.copy() annotations = existing_annotations.copy()
for i, ep_idx in enumerate(episode_indices): if args.num_workers > 1:
console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]") # ===== PARALLEL PROCESSING MODE =====
console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]") console.print(f"\n[bold cyan]Using {args.num_workers} parallel workers on GPUs: {gpu_ids}[/bold cyan]")
console.print(f"[bold cyan]{'=' * 60}[/bold cyan]")
result_ep_idx, annotation, error = process_single_episode(
ep_idx,
dataset.root,
dataset.meta,
video_key,
fps,
annotator,
console
)
if error: # Split episodes across workers
console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]") episodes_per_worker = [[] for _ in range(args.num_workers)]
continue for i, ep_idx in enumerate(episode_indices):
elif annotation: worker_idx = i % args.num_workers
annotations[result_ep_idx] = annotation episodes_per_worker[worker_idx].append(ep_idx)
display_annotation(annotation, console, result_ep_idx, fps)
save_annotations_to_dataset(dataset.root, annotations, fps) # Show distribution
for worker_id, episodes in enumerate(episodes_per_worker):
console.print(f"[cyan]Worker {worker_id} (GPU {gpu_ids[worker_id]}): {len(episodes)} episodes[/cyan]")
# Start parallel processing using ProcessPoolExecutor
console.print(f"\n[bold cyan]Starting parallel annotation...[/bold cyan]")
with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
# Submit all worker jobs
futures = []
for worker_id in range(args.num_workers):
if not episodes_per_worker[worker_id]:
continue # Skip workers with no episodes
future = executor.submit(
worker_process_episodes,
worker_id,
gpu_ids[worker_id],
episodes_per_worker[worker_id],
args.repo_id,
video_key,
subtask_list,
args.model,
torch_dtype,
)
futures.append(future)
# Collect results as they complete
for future in as_completed(futures):
try:
worker_annotations = future.result()
annotations.update(worker_annotations)
# Save after each worker completes
save_annotations_to_dataset(dataset.root, annotations, fps)
console.print(f"[green]✓ Worker completed, saved {len(worker_annotations)} annotations[/green]")
except Exception as e:
console.print(f"[red]✗ Worker failed: {e}[/red]")
console.print(f"\n[bold green]Parallel processing complete! Annotated {len(annotations)} episodes[/bold green]")
# Display all annotations
for ep_idx in sorted(annotations.keys()):
if ep_idx not in existing_annotations: # Only show newly annotated
display_annotation(annotations[ep_idx], console, ep_idx, fps)
else:
# ===== SEQUENTIAL PROCESSING MODE =====
console.print(f"\n[bold cyan]Using sequential processing (single GPU/CPU)[/bold cyan]")
# Initialize annotator with subtask list
annotator = VideoAnnotator(
subtask_list=subtask_list,
model_name=args.model,
device=args.device,
torch_dtype=torch_dtype
)
# Process episodes sequentially
for i, ep_idx in enumerate(episode_indices):
console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]")
console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]")
console.print(f"[bold cyan]{'=' * 60}[/bold cyan]")
result_ep_idx, annotation, error = process_single_episode(
ep_idx,
dataset.root,
dataset.meta,
video_key,
fps,
annotator,
console
)
if error:
console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]")
continue
elif annotation:
annotations[result_ep_idx] = annotation
display_annotation(annotation, console, result_ep_idx, fps)
save_annotations_to_dataset(dataset.root, annotations, fps)
# Compute temporal proportions (key SARM insight) # Compute temporal proportions (key SARM insight)
console.print(f"\n[bold cyan]Computing Temporal Proportions[/bold cyan]") console.print(f"\n[bold cyan]Computing Temporal Proportions[/bold cyan]")
+93 -7
View File
@@ -338,7 +338,9 @@ def visualize_predictions(
output_path: Path, output_path: Path,
show_frames: bool = False, show_frames: bool = False,
num_sample_frames: int = 8, num_sample_frames: int = 8,
figsize: tuple = (14, 8) figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
temporal_proportions: dict[str, float] | None = None
): ):
""" """
Create visualization of SARM predictions. Create visualization of SARM predictions.
@@ -352,10 +354,18 @@ def visualize_predictions(
show_frames: Whether to include sample frames show_frames: Whether to include sample frames
num_sample_frames: Number of frames to show num_sample_frames: Number of frames to show
figsize: Figure size (width, height) figsize: Figure size (width, height)
subtask_names: Optional list of subtask names for labeling
temporal_proportions: Optional dict of temporal proportions for each subtask
""" """
num_stages = stage_predictions.shape[1] num_stages = stage_predictions.shape[1]
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages)) stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
# Use subtask names if available, otherwise use generic labels
if subtask_names is not None and len(subtask_names) == num_stages:
stage_labels = subtask_names
else:
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
if show_frames: if show_frames:
# Create figure with progress plot, stage plot, and sample frames # Create figure with progress plot, stage plot, and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4)) fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
@@ -398,12 +408,35 @@ def visualize_predictions(
# Plot 2: Stage predictions (stacked area plot) # Plot 2: Stage predictions (stacked area plot)
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)], ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
colors=stage_colors, alpha=0.8, labels=[f'Stage {i+1}' for i in range(num_stages)]) colors=stage_colors, alpha=0.8, labels=stage_labels)
ax_stages.set_xlabel('Frame Index', fontsize=12) ax_stages.set_xlabel('Frame Index', fontsize=12)
ax_stages.set_ylabel('Stage Probability', fontsize=12) ax_stages.set_ylabel('Stage Probability', fontsize=12)
ax_stages.set_ylim(0, 1) ax_stages.set_ylim(0, 1)
ax_stages.grid(True, alpha=0.3) ax_stages.grid(True, alpha=0.3)
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
# Adjust legend based on number of stages and label lengths
if num_stages <= 5:
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
else:
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
if temporal_proportions is not None and subtask_names is not None:
cumulative_progress = 0.0
for i, name in enumerate(stage_labels):
if name in temporal_proportions:
# Find approximate frame where this stage should end
stage_end_progress = cumulative_progress + temporal_proportions[name]
# Find frame index closest to this progress
progress_diffs = np.abs(progress_predictions - stage_end_progress)
stage_end_frame = np.argmin(progress_diffs)
# Draw vertical line
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
cumulative_progress = stage_end_progress
# Plot 3: Sample frames (if requested) # Plot 3: Sample frames (if requested)
if show_frames: if show_frames:
@@ -431,7 +464,13 @@ def visualize_predictions(
# Add frame number, progress, and stage # Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx] progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx]) stage_idx = np.argmax(stage_predictions[frame_idx])
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\nStage: {stage_idx+1}' stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image # Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label, ax_frames.text(x_start + frame_width / 2, -10, label,
@@ -495,6 +534,29 @@ def main():
# Run inference # Run inference
progress_predictions, stage_predictions = run_inference(model, frames, states, task_description) progress_predictions, stage_predictions = run_inference(model, frames, states, task_description)
# Extract subtask names and temporal proportions from model config if available
subtask_names = None
temporal_proportions = None
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
subtask_names = model.config.subtask_names
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
# Try to load temporal proportions from model's dataset meta
if hasattr(model, 'dataset_stats') and model.dataset_stats is not None:
if 'temporal_proportions' in model.dataset_stats:
temporal_proportions = model.dataset_stats['temporal_proportions']
logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
# Fallback: try to load from dataset meta
if temporal_proportions is None and subtask_names is not None:
import json
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
if proportions_path.exists():
with open(proportions_path, 'r') as f:
temporal_proportions = json.load(f)
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# Create visualization # Create visualization
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png" output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
@@ -507,7 +569,9 @@ def main():
output_path, output_path,
show_frames=args.show_frames, show_frames=args.show_frames,
num_sample_frames=args.num_sample_frames, num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize) figsize=tuple(args.figsize),
subtask_names=subtask_names,
temporal_proportions=temporal_proportions
) )
# Save predictions as numpy arrays # Save predictions as numpy arrays
@@ -527,8 +591,30 @@ def main():
logger.info(f"Final Progress: {progress_predictions[-1]:.3f}") logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
logger.info(f"Max Progress: {progress_predictions.max():.3f}") logger.info(f"Max Progress: {progress_predictions.max():.3f}")
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}") logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
logger.info(f"Most Common Stage: {np.argmax(np.sum(stage_predictions, axis=0)) + 1}")
logger.info(f"Visualization: {output_path}") # Show most common stage with name if available
most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0))
if subtask_names is not None and most_common_stage_idx < len(subtask_names):
most_common_stage_name = subtask_names[most_common_stage_idx]
logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})")
else:
logger.info(f"Most Common Stage: {most_common_stage_idx + 1}")
# Show subtask breakdown
if subtask_names is not None:
logger.info("\nSubtask Breakdown:")
total_frames = len(stage_predictions)
for i, name in enumerate(subtask_names):
# Calculate percentage of frames where this stage was dominant
dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i)
percentage = (dominant_frames / total_frames) * 100
logger.info(f" {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)")
if temporal_proportions is not None and name in temporal_proportions:
expected_pct = temporal_proportions[name] * 100
logger.info(f" Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%")
logger.info(f"\nVisualization: {output_path}")
logger.info("="*60) logger.info("="*60)
@@ -49,7 +49,8 @@ class SARMConfig(PreTrainedConfig):
hidden_dim: int = 768 # Transformer hidden dimension hidden_dim: int = 768 # Transformer hidden dimension
num_heads: int = 12 # Number of attention heads num_heads: int = 12 # Number of attention heads
num_layers: int = 8 # Number of transformer layers num_layers: int = 8 # Number of transformer layers
num_stages: int = 5 # Number of task stages for classification num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
# Temporal parameters # Temporal parameters
max_length: int = 9 # Maximum video sequence length (should match num_frames) max_length: int = 9 # Maximum video sequence length (should match num_frames)
@@ -61,6 +62,7 @@ class SARMConfig(PreTrainedConfig):
clip_batch_size: int = 64 # Batch size for CLIP encoding clip_batch_size: int = 64 # Batch size for CLIP encoding
gradient_checkpointing: bool = False # Enable gradient checkpointing gradient_checkpointing: bool = False # Enable gradient checkpointing
dropout: float = 0.1 # Dropout rate dropout: float = 0.1 # Dropout rate
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
# RA-BC (Reward-Aligned Behavior Cloning) parameters # RA-BC (Reward-Aligned Behavior Cloning) parameters
enable_rabc: bool = False # Enable RA-BC weighted loss enable_rabc: bool = False # Enable RA-BC weighted loss
@@ -79,6 +81,7 @@ class SARMConfig(PreTrainedConfig):
task_description: str = "perform the task" # Default task description task_description: str = "perform the task" # Default task description
encode_on_the_fly: bool = True # Encode images/text during training encode_on_the_fly: bool = True # Encode images/text during training
use_dataset_task: bool = True # Use task descriptions from dataset use_dataset_task: bool = True # Use task descriptions from dataset
use_subtask_annotations: bool = True # Use subtask annotations for stage-aware training if available
# Features (required by PreTrainedPolicy) # Features (required by PreTrainedPolicy)
input_features: dict = field(default_factory=lambda: { input_features: dict = field(default_factory=lambda: {
+44 -7
View File
@@ -573,6 +573,15 @@ class SARMRewardModel(PreTrainedPolicy):
if state_features is not None: if state_features is not None:
state_features = state_features.to(self.device) state_features = state_features.to(self.device)
# Extract stage labels and progress targets if available (from subtask annotations)
stage_labels = observation.get('stage_labels', None)
if stage_labels is not None:
stage_labels = stage_labels.to(self.device)
progress_targets_from_annotations = observation.get('progress_targets', None)
if progress_targets_from_annotations is not None:
progress_targets_from_annotations = progress_targets_from_annotations.to(self.device)
batch_size = video_features.shape[0] batch_size = video_features.shape[0]
max_length = self.config.num_frames max_length = self.config.num_frames
@@ -713,16 +722,44 @@ class SARMRewardModel(PreTrainedPolicy):
processed_videos, text_features, processed_states processed_videos, text_features, processed_states
) )
# Compute progress loss using augmented targets # Use annotation-based progress targets if available, otherwise use computed ones
if progress_targets_from_annotations is not None and len(processed_videos) == 1:
# Use refined progress from subtask annotations (single sample case)
# Ensure shapes match
if progress_targets_from_annotations.shape != progress_preds.shape:
if progress_targets_from_annotations.dim() == 2:
progress_targets_from_annotations = progress_targets_from_annotations.unsqueeze(0)
progress_targets = progress_targets_from_annotations
# Compute progress loss using targets
progress_loss = F.mse_loss(progress_preds, progress_targets) progress_loss = F.mse_loss(progress_preds, progress_targets)
# For now, just use progress loss since we don't have stage annotations # Compute stage loss if labels are available
# In future: can add stage loss when we have annotated stage labels stage_loss = None
total_loss = progress_loss if stage_labels is not None and len(processed_videos) == 1:
# Ensure stage_labels matches the sequence length
if stage_labels.dim() == 1 and stage_logits.dim() == 3:
# stage_labels: (seq_len,) -> need to expand to (batch, seq_len)
stage_labels = stage_labels.unsqueeze(0).expand(stage_logits.shape[0], -1)
elif stage_labels.shape[0] != stage_logits.shape[0]:
# Single label for batch - expand
stage_labels = stage_labels.expand(stage_logits.shape[0], stage_logits.shape[1])
# Compute cross-entropy loss for stage classification
stage_loss = compute_stage_loss(stage_logits, stage_labels)
output_dict = { # Combine losses
'progress_loss': progress_loss.item(), if stage_loss is not None:
} total_loss = progress_loss + self.config.stage_loss_weight * stage_loss
output_dict = {
'progress_loss': progress_loss.item(),
'stage_loss': stage_loss.item(),
}
else:
total_loss = progress_loss
output_dict = {
'progress_loss': progress_loss.item(),
}
# Compute misaligned loss (following SARM paper and ReWiND) # Compute misaligned loss (following SARM paper and ReWiND)
# "To improve video-language alignment, task descriptions are occasionally perturbed" # "To improve video-language alignment, task descriptions are occasionally perturbed"
+225
View File
@@ -65,6 +65,12 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.dataset_meta = dataset_meta self.dataset_meta = dataset_meta
self.dataset_stats = dataset_stats self.dataset_stats = dataset_stats
# Compute temporal proportions from subtask annotations if available
self.temporal_proportions = None
self.subtask_names = None
if dataset_meta is not None and config.use_subtask_annotations:
self._compute_temporal_proportions()
# Initialize encoders # Initialize encoders
self._init_encoders() self._init_encoders()
@@ -95,6 +101,216 @@ class SARMEncodingProcessorStep(ProcessorStep):
self.device = device self.device = device
def _compute_temporal_proportions(self):
"""Compute temporal proportions for each subtask from dataset annotations."""
if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
return
# Check if subtask annotations exist
episodes = self.dataset_meta.episodes
if episodes is None or len(episodes) == 0:
return
# Check for subtask_names column
if 'subtask_names' not in episodes.column_names:
logging.info("No subtask annotations found in dataset")
return
# Convert to pandas for easier processing
import pandas as pd
episodes_df = episodes.to_pandas()
# Collect all subtask names and compute average durations
subtask_durations = {}
subtask_counts = {}
all_subtask_names = set()
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
# Skip episodes without annotations
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue
start_times = episodes_df.loc[ep_idx, 'subtask_start_times']
end_times = episodes_df.loc[ep_idx, 'subtask_end_times']
# Track unique subtask names
all_subtask_names.update(subtask_names)
# Compute durations
for i, name in enumerate(subtask_names):
duration = end_times[i] - start_times[i]
if name not in subtask_durations:
subtask_durations[name] = []
subtask_durations[name].append(duration)
if not all_subtask_names:
logging.info("No valid subtask annotations found")
return
# Sort subtask names for consistent ordering
self.subtask_names = sorted(list(all_subtask_names))
self.config.num_stages = len(self.subtask_names)
self.config.subtask_names = self.subtask_names # Store in config for reference
# Compute average duration for each subtask
avg_durations = {}
for name in self.subtask_names:
if name in subtask_durations:
avg_durations[name] = np.mean(subtask_durations[name])
else:
avg_durations[name] = 0.0
# Normalize to get proportions
total_duration = sum(avg_durations.values())
if total_duration > 0:
self.temporal_proportions = {
name: avg_durations[name] / total_duration
for name in self.subtask_names
}
else:
# Equal proportions if no duration info
self.temporal_proportions = {
name: 1.0 / len(self.subtask_names)
for name in self.subtask_names
}
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
"""Generate stage labels and refined progress targets from subtask annotations.
Args:
frame_index: Current frame index or indices
episode_index: Episode index
video_features: Video features tensor to determine sequence length
Returns:
Tuple of (stage_labels, progress_targets) or (None, None) if no annotations
"""
if self.temporal_proportions is None or episode_index is None:
return None, None
# Convert to pandas to access annotations
import pandas as pd
episodes_df = self.dataset_meta.episodes.to_pandas()
# Handle batch processing
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
if is_batch:
# Process multiple samples - for now, return None
# (batch processing of annotations is complex and not critical)
return None, None
# Single sample processing
if isinstance(episode_index, torch.Tensor):
ep_idx = int(episode_index.item())
else:
ep_idx = int(episode_index)
if isinstance(frame_index, torch.Tensor):
frame_idx = int(frame_index.item())
else:
frame_idx = int(frame_index)
# Get subtask annotations for this episode
if ep_idx >= len(episodes_df):
return None, None
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
return None, None
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
# Get episode boundaries
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
# Determine sequence length
if video_features is not None and video_features.dim() > 0:
seq_len = video_features.shape[0] if video_features.dim() == 2 else video_features.shape[1]
else:
seq_len = 1
# Generate labels for each frame in the sequence
stage_labels = []
progress_targets = []
# Get frame gap for temporal sampling
frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
for i in range(seq_len):
# Calculate actual frame index for this position in sequence
if frame_gap > 1:
offset = -(seq_len - 1 - i) * frame_gap
current_frame = max(0, frame_idx + offset - ep_start)
else:
current_frame = max(0, frame_idx - seq_len + 1 + i - ep_start)
# Find which subtask this frame belongs to
stage_idx = -1
within_subtask_progress = 0.0
cumulative_progress = 0.0
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
if current_frame >= start_frame and current_frame <= end_frame:
# Found the subtask
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
# Calculate within-subtask progress
subtask_duration = end_frame - start_frame
if subtask_duration > 0:
within_subtask_progress = (current_frame - start_frame) / subtask_duration
else:
within_subtask_progress = 1.0
# Calculate cumulative progress
for k in range(j):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
# Add current subtask's partial progress
if name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[name] * within_subtask_progress
break
# If no matching subtask found, estimate based on position
if stage_idx == -1:
# Estimate stage based on frame position
if current_frame < subtask_start_frames[0]:
stage_idx = 0
cumulative_progress = 0.0
elif current_frame > subtask_end_frames[-1]:
stage_idx = len(self.subtask_names) - 1
cumulative_progress = 1.0
else:
# Between subtasks - use previous subtask's end state
for j in range(len(subtask_names) - 1):
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
name = subtask_names[j]
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
# Sum up all previous subtasks
for k in range(j + 1):
prev_name = subtask_names[k]
if prev_name in self.temporal_proportions:
cumulative_progress += self.temporal_proportions[prev_name]
break
stage_labels.append(stage_idx)
progress_targets.append(cumulative_progress)
# Convert to tensors
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1) # Add channel dim
return stage_labels, progress_targets
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition.""" """Encode images, text, and normalize states in the transition."""
from lerobot.processor.core import TransitionKey from lerobot.processor.core import TransitionKey
@@ -351,6 +567,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
observation['episode_length'] = episode_length observation['episode_length'] = episode_length
# Generate stage labels and refined progress from subtask annotations
if self.temporal_proportions is not None and self.dataset_meta is not None:
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
frame_index, episode_index, observation.get('video_features')
)
if stage_labels is not None:
observation['stage_labels'] = stage_labels
observation['progress_targets'] = progress_targets
new_transition[TransitionKey.OBSERVATION] = observation new_transition[TransitionKey.OBSERVATION] = observation
return new_transition return new_transition