simplify and cleanup code and move compute_temporal_proportions to utils

This commit is contained in:
Pepijn
2025-11-27 19:38:32 +01:00
parent 73dd4f10f7
commit adc476d8af
5 changed files with 138 additions and 254 deletions
@@ -72,13 +72,11 @@ import argparse
import json
import time
from pathlib import Path
from typing import Optional
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import pandas as pd
import torch
from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info
from rich.console import Console
from rich.panel import Panel
@@ -86,24 +84,8 @@ from rich.tree import Tree
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
class Subtask(BaseModel):
"""Individual subtask/stage - must use EXACT names from provided list"""
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
timestamps: Timestamp
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
from lerobot.policies.sarm.sarm_utils import compute_temporal_proportions
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
def create_sarm_prompt(subtask_list: list[str]) -> str:
"""
@@ -769,59 +751,6 @@ def worker_process_episodes(
return annotations
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
"""
Compute average temporal proportion for each subtask across all episodes.
This is the key insight from SARM - use semantic subtasks instead of frame indices.
"""
# Collect all proportions per subtask
subtask_proportions = {}
for annotation in annotations.values():
# Calculate total episode duration
total_duration = 0
durations = {}
for subtask in annotation.subtasks:
# Parse timestamps
start_parts = subtask.timestamps.start.split(":")
end_parts = subtask.timestamps.end.split(":")
if len(start_parts) == 2:
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1])
else:
start_seconds = int(start_parts[0])
if len(end_parts) == 2:
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1])
else:
end_seconds = int(end_parts[0])
duration = end_seconds - start_seconds
durations[subtask.name] = duration
total_duration += duration
# Calculate proportions for this episode
if total_duration > 0:
for name, duration in durations.items():
if name not in subtask_proportions:
subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
# Average across episodes
avg_proportions = {
name: sum(props) / len(props)
for name, props in subtask_proportions.items()
}
# Normalize to sum to 1.0
total = sum(avg_proportions.values())
if total > 0:
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
return avg_proportions
def main():
parser = argparse.ArgumentParser(
description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
@@ -1185,4 +1114,3 @@ Performance Tips:
if __name__ == "__main__":
main()