mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
subtasks
This commit is contained in:
@@ -42,13 +42,30 @@ Usage:
|
|||||||
# Install dependencies
|
# Install dependencies
|
||||||
pip install transformers torch qwen-vl-utils accelerate
|
pip install transformers torch qwen-vl-utils accelerate
|
||||||
|
|
||||||
# Annotate and push to hub:
|
# Sequential processing (single GPU):
|
||||||
python subtask_annotation.py \\
|
python subtask_annotation.py \\
|
||||||
--repo-id pepijn223/mydataset \\
|
--repo-id pepijn223/mydataset \\
|
||||||
--subtasks "reach,grasp,lift,place" \\
|
--subtasks "reach,grasp,lift,place" \\
|
||||||
--video-key observation.images.base \\
|
--video-key observation.images.base \\
|
||||||
--push-to-hub
|
--push-to-hub
|
||||||
|
|
||||||
|
# Parallel processing (4 GPUs):
|
||||||
|
python subtask_annotation.py \\
|
||||||
|
--repo-id pepijn223/mydataset \\
|
||||||
|
--subtasks "reach,grasp,lift,place" \\
|
||||||
|
--video-key observation.images.base \\
|
||||||
|
--num-workers 4 \\
|
||||||
|
--push-to-hub
|
||||||
|
|
||||||
|
# Parallel with specific GPU IDs:
|
||||||
|
python subtask_annotation.py \\
|
||||||
|
--repo-id pepijn223/mydataset \\
|
||||||
|
--subtasks "reach,grasp,lift,place" \\
|
||||||
|
--video-key observation.images.base \\
|
||||||
|
--num-workers 2 \\
|
||||||
|
--gpu-ids 0 2 \\
|
||||||
|
--push-to-hub
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -56,6 +73,8 @@ import json
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
|
||||||
|
import multiprocessing as mp
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@@ -63,7 +82,7 @@ from pydantic import BaseModel, Field
|
|||||||
from qwen_vl_utils import process_vision_info
|
from qwen_vl_utils import process_vision_info
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
|
||||||
from rich.tree import Tree
|
from rich.tree import Tree
|
||||||
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
||||||
|
|
||||||
@@ -177,7 +196,7 @@ class VideoAnnotator:
|
|||||||
file_path: Path,
|
file_path: Path,
|
||||||
start_timestamp: float,
|
start_timestamp: float,
|
||||||
end_timestamp: float,
|
end_timestamp: float,
|
||||||
target_fps: int = 2
|
target_fps: int = 1
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""
|
"""
|
||||||
Extract a specific episode segment from concatenated video.
|
Extract a specific episode segment from concatenated video.
|
||||||
@@ -187,7 +206,7 @@ class VideoAnnotator:
|
|||||||
file_path: Path to the concatenated video file
|
file_path: Path to the concatenated video file
|
||||||
start_timestamp: Starting timestamp in seconds (within this video file)
|
start_timestamp: Starting timestamp in seconds (within this video file)
|
||||||
end_timestamp: Ending timestamp in seconds (within this video file)
|
end_timestamp: Ending timestamp in seconds (within this video file)
|
||||||
target_fps: Target FPS (default: 2 for faster processing)
|
target_fps: Target FPS (default: 1 for faster processing)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to extracted video file
|
Path to extracted video file
|
||||||
@@ -683,6 +702,77 @@ def process_single_episode(
|
|||||||
return ep_idx, None, str(e)
|
return ep_idx, None, str(e)
|
||||||
|
|
||||||
|
|
||||||
|
def worker_process_episodes(
    worker_id: int,
    gpu_id: int,
    episode_indices: list[int],
    repo_id: str,
    video_key: str,
    subtask_list: list[str],
    model_name: str,
    torch_dtype: torch.dtype,
) -> dict[int, SubtaskAnnotation]:
    """
    Annotate a share of the episodes on one GPU (parallel-mode worker).

    Each call builds its own dataset handle and its own ``VideoAnnotator``
    on ``cuda:<gpu_id>``, then runs ``process_single_episode`` for every
    assigned episode. A failure on one episode is reported on the console
    and does not stop the remaining episodes.

    Args:
        worker_id: Worker ID, used only to tag console messages
        gpu_id: GPU device ID to use (the annotator is placed on ``cuda:<gpu_id>``)
        episode_indices: List of episode indices to process
        repo_id: Dataset repo ID
        video_key: Video key to use
        subtask_list: List of subtask names
        model_name: Model name to load
        torch_dtype: Model dtype

    Returns:
        Dictionary of episode_idx -> SubtaskAnnotation containing only the
        episodes that were annotated successfully. Empty when
        ``episode_indices`` is empty.
    """
    # Nothing assigned to this worker: return before touching the dataset or
    # constructing the annotator, so no model is loaded just to do no work.
    if not episode_indices:
        return {}

    device = f"cuda:{gpu_id}"
    total = len(episode_indices)

    # Every worker prints through its own console instance.
    console = Console()
    console.print(f"[cyan]Worker {worker_id} starting on GPU {gpu_id} with {total} episodes[/cyan]")

    # download_videos=False is passed here, so this load is presumably
    # metadata-only -- NOTE(review): confirm the video files are resolved
    # later by process_single_episode.
    dataset = LeRobotDataset(repo_id, download_videos=False)
    fps = dataset.fps

    # One annotator (and therefore one model instance) per worker.
    annotator = VideoAnnotator(
        subtask_list=subtask_list,
        model_name=model_name,
        device=device,
        torch_dtype=torch_dtype,
    )

    annotations = {}

    for position, ep_idx in enumerate(episode_indices, start=1):
        console.print(f"[cyan]Worker {worker_id} | Episode {ep_idx} ({position}/{total})[/cyan]")

        # process_single_episode hands back (episode index, annotation, error).
        result_ep_idx, annotation, error = process_single_episode(
            ep_idx,
            dataset.root,
            dataset.meta,
            video_key,
            fps,
            annotator,
            console,
        )

        if error:
            # Best effort: report the failure and move on to the next episode.
            console.print(f"[red]Worker {worker_id} | ✗ Failed episode {result_ep_idx}: {error}[/red]")
        elif annotation:
            annotations[result_ep_idx] = annotation
            console.print(f"[green]Worker {worker_id} | ✓ Completed episode {result_ep_idx}[/green]")
        # An episode that yields neither an error nor an annotation is left
        # out of the result without a message.

    console.print(f"[bold green]Worker {worker_id} completed {len(annotations)}/{total} episodes[/bold green]")
    return annotations
|
||||||
|
|
||||||
|
|
||||||
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
|
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Compute average temporal proportion for each subtask across all episodes.
|
Compute average temporal proportion for each subtask across all episodes.
|
||||||
@@ -742,16 +832,41 @@ def main():
|
|||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
Examples:
|
Examples:
|
||||||
|
# Sequential processing (single GPU):
|
||||||
|
python subtask_annotation.py \\
|
||||||
|
--repo-id pepijn223/mydataset \\
|
||||||
|
--subtasks "reach,grasp,lift,place" \\
|
||||||
|
--video-key observation.images.top \\
|
||||||
|
--push-to-hub
|
||||||
|
|
||||||
|
# Parallel processing with 4 GPUs (4x speedup):
|
||||||
|
python subtask_annotation.py \\
|
||||||
|
--repo-id pepijn223/mydataset \\
|
||||||
|
--subtasks "reach,grasp,lift,place" \\
|
||||||
|
--video-key observation.images.top \\
|
||||||
|
--num-workers 4 \\
|
||||||
|
--push-to-hub
|
||||||
|
|
||||||
|
# Parallel with specific GPU IDs (e.g., GPUs 0, 2, 3):
|
||||||
|
python subtask_annotation.py \\
|
||||||
|
--repo-id pepijn223/mydataset \\
|
||||||
|
--subtasks "reach,grasp,lift,place" \\
|
||||||
|
--video-key observation.images.top \\
|
||||||
|
--num-workers 3 \\
|
||||||
|
--gpu-ids 0 2 3 \\
|
||||||
|
--push-to-hub
|
||||||
|
|
||||||
# List available cameras:
|
# List available cameras:
|
||||||
python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --max-episodes 0
|
python subtask_annotation.py \\
|
||||||
|
--repo-id pepijn223/mydataset \\
|
||||||
# Annotate with specific camera:
|
--subtasks "reach,grasp" \\
|
||||||
python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --push-to-hub
|
--max-episodes 0
|
||||||
|
|
||||||
# Use custom model:
|
|
||||||
python subtask_annotation.py --repo-id pepijn223/mydataset --subtasks "reach,grasp" --video-key observation.images.top --model Qwen/Qwen3-VL-30B-A3B-Instruct --push-to-hub
|
|
||||||
|
|
||||||
Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memory.
|
Performance Tips:
|
||||||
|
- Each worker loads one model instance on its assigned GPU
|
||||||
|
- The 30B model requires ~60GB VRAM per GPU
|
||||||
|
- Use --num-workers N for N GPUs to get N× speedup
|
||||||
|
- Episodes are distributed round-robin across workers
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -821,6 +936,27 @@ Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memor
|
|||||||
choices=["bfloat16", "float16", "float32"],
|
choices=["bfloat16", "float16", "float32"],
|
||||||
help="Model dtype (default: bfloat16)",
|
help="Model dtype (default: bfloat16)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of parallel workers for multi-GPU processing (default: 1 for sequential). "
|
||||||
|
"Set to number of GPUs available for parallel processing.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu-ids",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Specific GPU IDs to use (e.g., --gpu-ids 0 1 2). "
|
||||||
|
"If not specified, uses GPUs 0 to num-workers-1.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Batch size for processing multiple episodes per inference (experimental, default: 1)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -898,39 +1034,120 @@ Note: The 30B model requires ~60GB VRAM. Make sure you have sufficient GPU memor
|
|||||||
console.print("[green]All episodes already annotated![/green]")
|
console.print("[green]All episodes already annotated![/green]")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Initialize annotator with subtask list
|
# Determine GPU IDs to use
|
||||||
annotator = VideoAnnotator(
|
if args.gpu_ids:
|
||||||
subtask_list=subtask_list,
|
gpu_ids = args.gpu_ids
|
||||||
model_name=args.model,
|
if len(gpu_ids) < args.num_workers:
|
||||||
device=args.device,
|
console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {len(gpu_ids)} GPU IDs provided[/yellow]")
|
||||||
torch_dtype=torch_dtype
|
args.num_workers = len(gpu_ids)
|
||||||
)
|
else:
|
||||||
|
# Check available GPUs
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
if args.num_workers > num_gpus:
|
||||||
|
console.print(f"[yellow]Warning: {args.num_workers} workers requested but only {num_gpus} GPUs available[/yellow]")
|
||||||
|
args.num_workers = min(args.num_workers, num_gpus)
|
||||||
|
gpu_ids = list(range(args.num_workers))
|
||||||
|
else:
|
||||||
|
console.print("[yellow]Warning: CUDA not available, using CPU (num_workers will be ignored)[/yellow]")
|
||||||
|
args.num_workers = 1
|
||||||
|
gpu_ids = [0] # Dummy value for CPU
|
||||||
|
|
||||||
# Annotate episodes (sequential processing)
|
# Annotate episodes - choose sequential or parallel mode
|
||||||
annotations = existing_annotations.copy()
|
annotations = existing_annotations.copy()
|
||||||
|
|
||||||
for i, ep_idx in enumerate(episode_indices):
|
if args.num_workers > 1:
|
||||||
console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]")
|
# ===== PARALLEL PROCESSING MODE =====
|
||||||
console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]")
|
console.print(f"\n[bold cyan]Using {args.num_workers} parallel workers on GPUs: {gpu_ids}[/bold cyan]")
|
||||||
console.print(f"[bold cyan]{'=' * 60}[/bold cyan]")
|
|
||||||
|
|
||||||
result_ep_idx, annotation, error = process_single_episode(
|
|
||||||
ep_idx,
|
|
||||||
dataset.root,
|
|
||||||
dataset.meta,
|
|
||||||
video_key,
|
|
||||||
fps,
|
|
||||||
annotator,
|
|
||||||
console
|
|
||||||
)
|
|
||||||
|
|
||||||
if error:
|
# Split episodes across workers
|
||||||
console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]")
|
episodes_per_worker = [[] for _ in range(args.num_workers)]
|
||||||
continue
|
for i, ep_idx in enumerate(episode_indices):
|
||||||
elif annotation:
|
worker_idx = i % args.num_workers
|
||||||
annotations[result_ep_idx] = annotation
|
episodes_per_worker[worker_idx].append(ep_idx)
|
||||||
display_annotation(annotation, console, result_ep_idx, fps)
|
|
||||||
save_annotations_to_dataset(dataset.root, annotations, fps)
|
# Show distribution
|
||||||
|
for worker_id, episodes in enumerate(episodes_per_worker):
|
||||||
|
console.print(f"[cyan]Worker {worker_id} (GPU {gpu_ids[worker_id]}): {len(episodes)} episodes[/cyan]")
|
||||||
|
|
||||||
|
# Start parallel processing using ProcessPoolExecutor
|
||||||
|
console.print(f"\n[bold cyan]Starting parallel annotation...[/bold cyan]")
|
||||||
|
|
||||||
|
with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
|
||||||
|
# Submit all worker jobs
|
||||||
|
futures = []
|
||||||
|
for worker_id in range(args.num_workers):
|
||||||
|
if not episodes_per_worker[worker_id]:
|
||||||
|
continue # Skip workers with no episodes
|
||||||
|
|
||||||
|
future = executor.submit(
|
||||||
|
worker_process_episodes,
|
||||||
|
worker_id,
|
||||||
|
gpu_ids[worker_id],
|
||||||
|
episodes_per_worker[worker_id],
|
||||||
|
args.repo_id,
|
||||||
|
video_key,
|
||||||
|
subtask_list,
|
||||||
|
args.model,
|
||||||
|
torch_dtype,
|
||||||
|
)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
# Collect results as they complete
|
||||||
|
for future in as_completed(futures):
|
||||||
|
try:
|
||||||
|
worker_annotations = future.result()
|
||||||
|
annotations.update(worker_annotations)
|
||||||
|
|
||||||
|
# Save after each worker completes
|
||||||
|
save_annotations_to_dataset(dataset.root, annotations, fps)
|
||||||
|
console.print(f"[green]✓ Worker completed, saved {len(worker_annotations)} annotations[/green]")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Worker failed: {e}[/red]")
|
||||||
|
|
||||||
|
console.print(f"\n[bold green]Parallel processing complete! Annotated {len(annotations)} episodes[/bold green]")
|
||||||
|
|
||||||
|
# Display all annotations
|
||||||
|
for ep_idx in sorted(annotations.keys()):
|
||||||
|
if ep_idx not in existing_annotations: # Only show newly annotated
|
||||||
|
display_annotation(annotations[ep_idx], console, ep_idx, fps)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# ===== SEQUENTIAL PROCESSING MODE =====
|
||||||
|
console.print(f"\n[bold cyan]Using sequential processing (single GPU/CPU)[/bold cyan]")
|
||||||
|
|
||||||
|
# Initialize annotator with subtask list
|
||||||
|
annotator = VideoAnnotator(
|
||||||
|
subtask_list=subtask_list,
|
||||||
|
model_name=args.model,
|
||||||
|
device=args.device,
|
||||||
|
torch_dtype=torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process episodes sequentially
|
||||||
|
for i, ep_idx in enumerate(episode_indices):
|
||||||
|
console.print(f"\n[bold cyan]{'=' * 60}[/bold cyan]")
|
||||||
|
console.print(f"[bold cyan]Episode {ep_idx} ({i + 1}/{len(episode_indices)})[/bold cyan]")
|
||||||
|
console.print(f"[bold cyan]{'=' * 60}[/bold cyan]")
|
||||||
|
|
||||||
|
result_ep_idx, annotation, error = process_single_episode(
|
||||||
|
ep_idx,
|
||||||
|
dataset.root,
|
||||||
|
dataset.meta,
|
||||||
|
video_key,
|
||||||
|
fps,
|
||||||
|
annotator,
|
||||||
|
console
|
||||||
|
)
|
||||||
|
|
||||||
|
if error:
|
||||||
|
console.print(f"[red]✗ Failed to annotate episode {result_ep_idx}: {error}[/red]")
|
||||||
|
continue
|
||||||
|
elif annotation:
|
||||||
|
annotations[result_ep_idx] = annotation
|
||||||
|
display_annotation(annotation, console, result_ep_idx, fps)
|
||||||
|
save_annotations_to_dataset(dataset.root, annotations, fps)
|
||||||
|
|
||||||
# Compute temporal proportions (key SARM insight)
|
# Compute temporal proportions (key SARM insight)
|
||||||
console.print(f"\n[bold cyan]Computing Temporal Proportions[/bold cyan]")
|
console.print(f"\n[bold cyan]Computing Temporal Proportions[/bold cyan]")
|
||||||
|
|||||||
@@ -338,7 +338,9 @@ def visualize_predictions(
|
|||||||
output_path: Path,
|
output_path: Path,
|
||||||
show_frames: bool = False,
|
show_frames: bool = False,
|
||||||
num_sample_frames: int = 8,
|
num_sample_frames: int = 8,
|
||||||
figsize: tuple = (14, 8)
|
figsize: tuple = (14, 8),
|
||||||
|
subtask_names: list[str] | None = None,
|
||||||
|
temporal_proportions: dict[str, float] | None = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create visualization of SARM predictions.
|
Create visualization of SARM predictions.
|
||||||
@@ -352,10 +354,18 @@ def visualize_predictions(
|
|||||||
show_frames: Whether to include sample frames
|
show_frames: Whether to include sample frames
|
||||||
num_sample_frames: Number of frames to show
|
num_sample_frames: Number of frames to show
|
||||||
figsize: Figure size (width, height)
|
figsize: Figure size (width, height)
|
||||||
|
subtask_names: Optional list of subtask names for labeling
|
||||||
|
temporal_proportions: Optional dict of temporal proportions for each subtask
|
||||||
"""
|
"""
|
||||||
num_stages = stage_predictions.shape[1]
|
num_stages = stage_predictions.shape[1]
|
||||||
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
|
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
|
||||||
|
|
||||||
|
# Use subtask names if available, otherwise use generic labels
|
||||||
|
if subtask_names is not None and len(subtask_names) == num_stages:
|
||||||
|
stage_labels = subtask_names
|
||||||
|
else:
|
||||||
|
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
|
||||||
|
|
||||||
if show_frames:
|
if show_frames:
|
||||||
# Create figure with progress plot, stage plot, and sample frames
|
# Create figure with progress plot, stage plot, and sample frames
|
||||||
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
|
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
|
||||||
@@ -398,12 +408,35 @@ def visualize_predictions(
|
|||||||
|
|
||||||
# Plot 2: Stage predictions (stacked area plot)
|
# Plot 2: Stage predictions (stacked area plot)
|
||||||
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
|
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
|
||||||
colors=stage_colors, alpha=0.8, labels=[f'Stage {i+1}' for i in range(num_stages)])
|
colors=stage_colors, alpha=0.8, labels=stage_labels)
|
||||||
ax_stages.set_xlabel('Frame Index', fontsize=12)
|
ax_stages.set_xlabel('Frame Index', fontsize=12)
|
||||||
ax_stages.set_ylabel('Stage Probability', fontsize=12)
|
ax_stages.set_ylabel('Stage Probability', fontsize=12)
|
||||||
ax_stages.set_ylim(0, 1)
|
ax_stages.set_ylim(0, 1)
|
||||||
ax_stages.grid(True, alpha=0.3)
|
ax_stages.grid(True, alpha=0.3)
|
||||||
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
|
|
||||||
|
# Adjust legend based on number of stages and label lengths
|
||||||
|
if num_stages <= 5:
|
||||||
|
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
|
||||||
|
else:
|
||||||
|
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
|
||||||
|
|
||||||
|
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
|
||||||
|
if temporal_proportions is not None and subtask_names is not None:
|
||||||
|
cumulative_progress = 0.0
|
||||||
|
for i, name in enumerate(stage_labels):
|
||||||
|
if name in temporal_proportions:
|
||||||
|
# Find approximate frame where this stage should end
|
||||||
|
stage_end_progress = cumulative_progress + temporal_proportions[name]
|
||||||
|
|
||||||
|
# Find frame index closest to this progress
|
||||||
|
progress_diffs = np.abs(progress_predictions - stage_end_progress)
|
||||||
|
stage_end_frame = np.argmin(progress_diffs)
|
||||||
|
|
||||||
|
# Draw vertical line
|
||||||
|
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
|
||||||
|
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
|
||||||
|
|
||||||
|
cumulative_progress = stage_end_progress
|
||||||
|
|
||||||
# Plot 3: Sample frames (if requested)
|
# Plot 3: Sample frames (if requested)
|
||||||
if show_frames:
|
if show_frames:
|
||||||
@@ -431,7 +464,13 @@ def visualize_predictions(
|
|||||||
# Add frame number, progress, and stage
|
# Add frame number, progress, and stage
|
||||||
progress_val = progress_predictions[frame_idx]
|
progress_val = progress_predictions[frame_idx]
|
||||||
stage_idx = np.argmax(stage_predictions[frame_idx])
|
stage_idx = np.argmax(stage_predictions[frame_idx])
|
||||||
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\nStage: {stage_idx+1}'
|
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
|
||||||
|
|
||||||
|
# Truncate long stage names for display
|
||||||
|
if len(stage_name) > 15:
|
||||||
|
stage_name = stage_name[:12] + '...'
|
||||||
|
|
||||||
|
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
|
||||||
|
|
||||||
# Draw label on image
|
# Draw label on image
|
||||||
ax_frames.text(x_start + frame_width / 2, -10, label,
|
ax_frames.text(x_start + frame_width / 2, -10, label,
|
||||||
@@ -495,6 +534,29 @@ def main():
|
|||||||
# Run inference
|
# Run inference
|
||||||
progress_predictions, stage_predictions = run_inference(model, frames, states, task_description)
|
progress_predictions, stage_predictions = run_inference(model, frames, states, task_description)
|
||||||
|
|
||||||
|
# Extract subtask names and temporal proportions from model config if available
|
||||||
|
subtask_names = None
|
||||||
|
temporal_proportions = None
|
||||||
|
|
||||||
|
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
|
||||||
|
subtask_names = model.config.subtask_names
|
||||||
|
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
|
||||||
|
|
||||||
|
# Try to load temporal proportions from model's dataset meta
|
||||||
|
if hasattr(model, 'dataset_stats') and model.dataset_stats is not None:
|
||||||
|
if 'temporal_proportions' in model.dataset_stats:
|
||||||
|
temporal_proportions = model.dataset_stats['temporal_proportions']
|
||||||
|
logger.info(f"✓ Found temporal proportions in model: {temporal_proportions}")
|
||||||
|
|
||||||
|
# Fallback: try to load from dataset meta
|
||||||
|
if temporal_proportions is None and subtask_names is not None:
|
||||||
|
import json
|
||||||
|
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
|
||||||
|
if proportions_path.exists():
|
||||||
|
with open(proportions_path, 'r') as f:
|
||||||
|
temporal_proportions = json.load(f)
|
||||||
|
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
|
||||||
|
|
||||||
# Create visualization
|
# Create visualization
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
|
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
|
||||||
@@ -507,7 +569,9 @@ def main():
|
|||||||
output_path,
|
output_path,
|
||||||
show_frames=args.show_frames,
|
show_frames=args.show_frames,
|
||||||
num_sample_frames=args.num_sample_frames,
|
num_sample_frames=args.num_sample_frames,
|
||||||
figsize=tuple(args.figsize)
|
figsize=tuple(args.figsize),
|
||||||
|
subtask_names=subtask_names,
|
||||||
|
temporal_proportions=temporal_proportions
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save predictions as numpy arrays
|
# Save predictions as numpy arrays
|
||||||
@@ -527,8 +591,30 @@ def main():
|
|||||||
logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
|
logger.info(f"Final Progress: {progress_predictions[-1]:.3f}")
|
||||||
logger.info(f"Max Progress: {progress_predictions.max():.3f}")
|
logger.info(f"Max Progress: {progress_predictions.max():.3f}")
|
||||||
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
|
logger.info(f"Mean Progress: {progress_predictions.mean():.3f}")
|
||||||
logger.info(f"Most Common Stage: {np.argmax(np.sum(stage_predictions, axis=0)) + 1}")
|
|
||||||
logger.info(f"Visualization: {output_path}")
|
# Show most common stage with name if available
|
||||||
|
most_common_stage_idx = np.argmax(np.sum(stage_predictions, axis=0))
|
||||||
|
if subtask_names is not None and most_common_stage_idx < len(subtask_names):
|
||||||
|
most_common_stage_name = subtask_names[most_common_stage_idx]
|
||||||
|
logger.info(f"Most Common Stage: {most_common_stage_name} (Stage {most_common_stage_idx + 1})")
|
||||||
|
else:
|
||||||
|
logger.info(f"Most Common Stage: {most_common_stage_idx + 1}")
|
||||||
|
|
||||||
|
# Show subtask breakdown
|
||||||
|
if subtask_names is not None:
|
||||||
|
logger.info("\nSubtask Breakdown:")
|
||||||
|
total_frames = len(stage_predictions)
|
||||||
|
for i, name in enumerate(subtask_names):
|
||||||
|
# Calculate percentage of frames where this stage was dominant
|
||||||
|
dominant_frames = np.sum(np.argmax(stage_predictions, axis=1) == i)
|
||||||
|
percentage = (dominant_frames / total_frames) * 100
|
||||||
|
logger.info(f" {name}: {dominant_frames}/{total_frames} frames ({percentage:.1f}%)")
|
||||||
|
|
||||||
|
if temporal_proportions is not None and name in temporal_proportions:
|
||||||
|
expected_pct = temporal_proportions[name] * 100
|
||||||
|
logger.info(f" Expected: {expected_pct:.1f}% | Actual: {percentage:.1f}%")
|
||||||
|
|
||||||
|
logger.info(f"\nVisualization: {output_path}")
|
||||||
logger.info("="*60)
|
logger.info("="*60)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,8 @@ class SARMConfig(PreTrainedConfig):
|
|||||||
hidden_dim: int = 768 # Transformer hidden dimension
|
hidden_dim: int = 768 # Transformer hidden dimension
|
||||||
num_heads: int = 12 # Number of attention heads
|
num_heads: int = 12 # Number of attention heads
|
||||||
num_layers: int = 8 # Number of transformer layers
|
num_layers: int = 8 # Number of transformer layers
|
||||||
num_stages: int = 5 # Number of task stages for classification
|
num_stages: int = 5 # Number of task stages for classification (auto-updated from annotations if available)
|
||||||
|
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
|
||||||
|
|
||||||
# Temporal parameters
|
# Temporal parameters
|
||||||
max_length: int = 9 # Maximum video sequence length (should match num_frames)
|
max_length: int = 9 # Maximum video sequence length (should match num_frames)
|
||||||
@@ -61,6 +62,7 @@ class SARMConfig(PreTrainedConfig):
|
|||||||
clip_batch_size: int = 64 # Batch size for CLIP encoding
|
clip_batch_size: int = 64 # Batch size for CLIP encoding
|
||||||
gradient_checkpointing: bool = False # Enable gradient checkpointing
|
gradient_checkpointing: bool = False # Enable gradient checkpointing
|
||||||
dropout: float = 0.1 # Dropout rate
|
dropout: float = 0.1 # Dropout rate
|
||||||
|
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
|
||||||
|
|
||||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||||
enable_rabc: bool = False # Enable RA-BC weighted loss
|
enable_rabc: bool = False # Enable RA-BC weighted loss
|
||||||
@@ -79,6 +81,7 @@ class SARMConfig(PreTrainedConfig):
|
|||||||
task_description: str = "perform the task" # Default task description
|
task_description: str = "perform the task" # Default task description
|
||||||
encode_on_the_fly: bool = True # Encode images/text during training
|
encode_on_the_fly: bool = True # Encode images/text during training
|
||||||
use_dataset_task: bool = True # Use task descriptions from dataset
|
use_dataset_task: bool = True # Use task descriptions from dataset
|
||||||
|
use_subtask_annotations: bool = True # Use subtask annotations for stage-aware training if available
|
||||||
|
|
||||||
# Features (required by PreTrainedPolicy)
|
# Features (required by PreTrainedPolicy)
|
||||||
input_features: dict = field(default_factory=lambda: {
|
input_features: dict = field(default_factory=lambda: {
|
||||||
|
|||||||
@@ -573,6 +573,15 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
if state_features is not None:
|
if state_features is not None:
|
||||||
state_features = state_features.to(self.device)
|
state_features = state_features.to(self.device)
|
||||||
|
|
||||||
|
# Extract stage labels and progress targets if available (from subtask annotations)
|
||||||
|
stage_labels = observation.get('stage_labels', None)
|
||||||
|
if stage_labels is not None:
|
||||||
|
stage_labels = stage_labels.to(self.device)
|
||||||
|
|
||||||
|
progress_targets_from_annotations = observation.get('progress_targets', None)
|
||||||
|
if progress_targets_from_annotations is not None:
|
||||||
|
progress_targets_from_annotations = progress_targets_from_annotations.to(self.device)
|
||||||
|
|
||||||
batch_size = video_features.shape[0]
|
batch_size = video_features.shape[0]
|
||||||
max_length = self.config.num_frames
|
max_length = self.config.num_frames
|
||||||
|
|
||||||
@@ -713,16 +722,44 @@ class SARMRewardModel(PreTrainedPolicy):
|
|||||||
processed_videos, text_features, processed_states
|
processed_videos, text_features, processed_states
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute progress loss using augmented targets
|
# Use annotation-based progress targets if available, otherwise use computed ones
|
||||||
|
if progress_targets_from_annotations is not None and len(processed_videos) == 1:
|
||||||
|
# Use refined progress from subtask annotations (single sample case)
|
||||||
|
# Ensure shapes match
|
||||||
|
if progress_targets_from_annotations.shape != progress_preds.shape:
|
||||||
|
if progress_targets_from_annotations.dim() == 2:
|
||||||
|
progress_targets_from_annotations = progress_targets_from_annotations.unsqueeze(0)
|
||||||
|
progress_targets = progress_targets_from_annotations
|
||||||
|
|
||||||
|
# Compute progress loss using targets
|
||||||
progress_loss = F.mse_loss(progress_preds, progress_targets)
|
progress_loss = F.mse_loss(progress_preds, progress_targets)
|
||||||
|
|
||||||
# For now, just use progress loss since we don't have stage annotations
|
# Compute stage loss if labels are available
|
||||||
# In future: can add stage loss when we have annotated stage labels
|
stage_loss = None
|
||||||
total_loss = progress_loss
|
if stage_labels is not None and len(processed_videos) == 1:
|
||||||
|
# Ensure stage_labels matches the sequence length
|
||||||
|
if stage_labels.dim() == 1 and stage_logits.dim() == 3:
|
||||||
|
# stage_labels: (seq_len,) -> need to expand to (batch, seq_len)
|
||||||
|
stage_labels = stage_labels.unsqueeze(0).expand(stage_logits.shape[0], -1)
|
||||||
|
elif stage_labels.shape[0] != stage_logits.shape[0]:
|
||||||
|
# Single label for batch - expand
|
||||||
|
stage_labels = stage_labels.expand(stage_logits.shape[0], stage_logits.shape[1])
|
||||||
|
|
||||||
|
# Compute cross-entropy loss for stage classification
|
||||||
|
stage_loss = compute_stage_loss(stage_logits, stage_labels)
|
||||||
|
|
||||||
output_dict = {
|
# Combine losses
|
||||||
'progress_loss': progress_loss.item(),
|
if stage_loss is not None:
|
||||||
}
|
total_loss = progress_loss + self.config.stage_loss_weight * stage_loss
|
||||||
|
output_dict = {
|
||||||
|
'progress_loss': progress_loss.item(),
|
||||||
|
'stage_loss': stage_loss.item(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
total_loss = progress_loss
|
||||||
|
output_dict = {
|
||||||
|
'progress_loss': progress_loss.item(),
|
||||||
|
}
|
||||||
|
|
||||||
# Compute misaligned loss (following SARM paper and ReWiND)
|
# Compute misaligned loss (following SARM paper and ReWiND)
|
||||||
# "To improve video-language alignment, task descriptions are occasionally perturbed"
|
# "To improve video-language alignment, task descriptions are occasionally perturbed"
|
||||||
|
|||||||
@@ -65,6 +65,12 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
self.dataset_meta = dataset_meta
|
self.dataset_meta = dataset_meta
|
||||||
self.dataset_stats = dataset_stats
|
self.dataset_stats = dataset_stats
|
||||||
|
|
||||||
|
# Compute temporal proportions from subtask annotations if available
|
||||||
|
self.temporal_proportions = None
|
||||||
|
self.subtask_names = None
|
||||||
|
if dataset_meta is not None and config.use_subtask_annotations:
|
||||||
|
self._compute_temporal_proportions()
|
||||||
|
|
||||||
# Initialize encoders
|
# Initialize encoders
|
||||||
self._init_encoders()
|
self._init_encoders()
|
||||||
|
|
||||||
@@ -95,6 +101,216 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
def _compute_temporal_proportions(self):
    """Derive per-subtask temporal proportions from the episode annotations.

    Reads ``self.dataset_meta.episodes`` and, for every annotated episode,
    collects ``end - start`` for each subtask from the
    ``subtask_start_times`` / ``subtask_end_times`` columns. The durations
    are averaged per subtask name and normalised so they sum to 1.

    On success this sets:
        * ``self.subtask_names`` - alphabetically sorted unique subtask names
        * ``self.temporal_proportions`` - ``{name: share of total duration}``
        * ``self.config.num_stages`` and ``self.config.subtask_names``

    The method returns ``None`` in every case. When the metadata, the
    annotation columns, or any usable annotation is missing it returns early
    and leaves the attributes above untouched.
    """
    if self.dataset_meta is None or not hasattr(self.dataset_meta, 'episodes'):
        return

    episodes = self.dataset_meta.episodes
    if episodes is None or len(episodes) == 0:
        return

    columns = episodes.column_names
    if 'subtask_names' not in columns:
        logging.info("No subtask annotations found in dataset")
        return

    # Names without timing columns used to surface as a KeyError while the
    # processor was being constructed; treat that as "no usable annotations".
    missing_columns = [
        column
        for column in ('subtask_start_times', 'subtask_end_times')
        if column not in columns
    ]
    if missing_columns:
        logging.warning(
            f"Subtask annotations are missing timing columns {missing_columns}; "
            "skipping temporal proportions"
        )
        return

    # Convert to pandas for easier row-wise processing.
    import pandas as pd

    episodes_df = episodes.to_pandas()

    subtask_durations = {}
    all_subtask_names = set()

    annotated_rows = zip(
        episodes_df['subtask_names'],
        episodes_df['subtask_start_times'],
        episodes_df['subtask_end_times'],
    )
    for names, start_times, end_times in annotated_rows:
        # Episodes without annotations show up as None or a float NaN cell.
        if names is None or (isinstance(names, float) and pd.isna(names)):
            continue

        all_subtask_names.update(names)
        for name, start, end in zip(names, start_times, end_times):
            subtask_durations.setdefault(name, []).append(end - start)

    if not all_subtask_names:
        logging.info("No valid subtask annotations found")
        return

    # Sorted so the name -> stage index mapping is the same on every run.
    self.subtask_names = sorted(all_subtask_names)
    self.config.num_stages = len(self.subtask_names)
    self.config.subtask_names = self.subtask_names  # Store in config for reference

    # Average duration per subtask; a name with no recorded duration counts as 0.
    avg_durations = {}
    for name in self.subtask_names:
        durations = subtask_durations.get(name)
        avg_durations[name] = np.mean(durations) if durations else 0.0

    total_duration = sum(avg_durations.values())
    if total_duration > 0:
        self.temporal_proportions = {
            name: avg_durations[name] / total_duration
            for name in self.subtask_names
        }
    else:
        # No usable duration information: fall back to equal shares.
        equal_share = 1.0 / len(self.subtask_names)
        self.temporal_proportions = {name: equal_share for name in self.subtask_names}

    logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
|
||||||
|
|
||||||
|
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
    """Generate stage labels and refined progress targets from subtask annotations.

    Args:
        frame_index: Index of the last frame of the sequence, as an int or a
            single-element tensor. ``dataset_from_index`` of the episode is
            subtracted from it, so it is presumably a dataset-global index
            (TODO confirm against the caller). Tensors with more than one
            element are treated as a batch, which is not supported.
        episode_index: Episode index, as an int or a single-element tensor.
        video_features: Optional tensor used only to determine the sequence
            length: ``shape[0]`` for 2-D input, ``shape[1]`` for input with
            three or more dimensions, otherwise a single frame.

    Returns:
        Tuple ``(stage_labels, progress_targets)`` where ``stage_labels`` is a
        ``torch.long`` tensor of shape ``(seq_len,)`` holding indices into
        ``self.subtask_names`` and ``progress_targets`` is a ``torch.float32``
        tensor of shape ``(seq_len, 1)``. Returns ``(None, None)`` when no
        annotations are available for the request.
    """
    if self.temporal_proportions is None or episode_index is None:
        return None, None

    # Batch processing of annotations is not implemented; callers receive no
    # labels for batched frame indices.
    if isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1:
        return None, None

    if isinstance(episode_index, torch.Tensor):
        ep_idx = int(episode_index.item())
    else:
        ep_idx = int(episode_index)

    if isinstance(frame_index, torch.Tensor):
        frame_idx = int(frame_index.item())
    else:
        frame_idx = int(frame_index)

    # Convert to pandas only after the cheap early exits above.
    import pandas as pd

    episodes_df = self.dataset_meta.episodes.to_pandas()

    # A negative index would pass a plain upper-bound check and then fail
    # inside the label lookup, so reject it here as well.
    if ep_idx < 0 or ep_idx >= len(episodes_df):
        return None, None

    subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
    if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
        return None, None
    if len(subtask_names) == 0:
        # An empty annotation list carries no stage information.
        return None, None

    subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
    subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']

    # First dataset index of the episode; turns frame_idx into an
    # episode-relative frame number.
    ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]

    # Determine sequence length
    if video_features is not None and video_features.dim() == 2:
        seq_len = video_features.shape[0]
    elif video_features is not None and video_features.dim() > 2:
        seq_len = video_features.shape[1]
    else:
        # None, 0-D or 1-D features: treat as a single frame.
        seq_len = 1

    # Stride between consecutive sequence positions; anything not above 1
    # behaves as a stride of one frame.
    frame_gap = self.config.frame_gap if hasattr(self.config, 'frame_gap') else 1
    stride = frame_gap if frame_gap > 1 else 1

    proportions = self.temporal_proportions

    def stage_index_of(name, fallback):
        # Stage ids are positions in the sorted global vocabulary.
        return self.subtask_names.index(name) if name in self.subtask_names else fallback

    # completed[j] is the summed proportion of the first j subtasks of this
    # episode, taken in the episode's own temporal order.
    completed = [0.0]
    for name in subtask_names:
        completed.append(completed[-1] + proportions.get(name, 0.0))

    segments = list(zip(subtask_names, subtask_start_frames, subtask_end_frames))

    stage_labels = []
    progress_targets = []

    for i in range(seq_len):
        # Position i looks back (seq_len - 1 - i) strides from frame_idx and is
        # clamped to the first frame of the episode.
        current_frame = max(0, frame_idx - (seq_len - 1 - i) * stride - ep_start)

        stage_idx = -1
        cumulative_progress = 0.0

        for j, (name, start_frame, end_frame) in enumerate(segments):
            if start_frame <= current_frame <= end_frame:
                stage_idx = stage_index_of(name, 0)

                subtask_duration = end_frame - start_frame
                if subtask_duration > 0:
                    within_subtask_progress = (current_frame - start_frame) / subtask_duration
                else:
                    within_subtask_progress = 1.0

                cumulative_progress = completed[j]
                if name in proportions:
                    cumulative_progress += proportions[name] * within_subtask_progress
                break

        if stage_idx == -1:
            if current_frame < subtask_start_frames[0]:
                # Before the first annotated subtask: label it with this
                # episode's first subtask. Index 0 of the sorted vocabulary is
                # a different subtask unless that name sorts first.
                stage_idx = stage_index_of(subtask_names[0], 0)
                cumulative_progress = 0.0
            elif current_frame > subtask_end_frames[-1]:
                # After the last annotated subtask: label it with this
                # episode's last subtask, for the same reason.
                stage_idx = stage_index_of(subtask_names[-1], len(self.subtask_names) - 1)
                cumulative_progress = 1.0
            else:
                # Between subtasks - use previous subtask's end state
                for j in range(len(subtask_names) - 1):
                    if subtask_end_frames[j] < current_frame < subtask_start_frames[j + 1]:
                        stage_idx = stage_index_of(subtask_names[j], j)
                        cumulative_progress = completed[j + 1]
                        break
                # NOTE(review): if no gap matches (e.g. overlapping or
                # unordered annotations) the label stays -1 with progress 0.0.

        stage_labels.append(stage_idx)
        progress_targets.append(cumulative_progress)

    # Convert to tensors
    stage_labels = torch.tensor(stage_labels, dtype=torch.long)
    progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)  # Add channel dim

    return stage_labels, progress_targets
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
"""Encode images, text, and normalize states in the transition."""
|
"""Encode images, text, and normalize states in the transition."""
|
||||||
from lerobot.processor.core import TransitionKey
|
from lerobot.processor.core import TransitionKey
|
||||||
@@ -351,6 +567,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
|
|
||||||
observation['episode_length'] = episode_length
|
observation['episode_length'] = episode_length
|
||||||
|
|
||||||
|
# Generate stage labels and refined progress from subtask annotations
|
||||||
|
if self.temporal_proportions is not None and self.dataset_meta is not None:
|
||||||
|
stage_labels, progress_targets = self._generate_stage_and_progress_labels(
|
||||||
|
frame_index, episode_index, observation.get('video_features')
|
||||||
|
)
|
||||||
|
if stage_labels is not None:
|
||||||
|
observation['stage_labels'] = stage_labels
|
||||||
|
observation['progress_targets'] = progress_targets
|
||||||
|
|
||||||
new_transition[TransitionKey.OBSERVATION] = observation
|
new_transition[TransitionKey.OBSERVATION] = observation
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user