mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
simplify and cleanup code and move compute_temporal_proportions to utils
This commit is contained in:
@@ -72,13 +72,11 @@ import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
import multiprocessing as mp
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -86,24 +84,8 @@ from rich.tree import Tree
|
||||
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Pydantic Models for SARM-style Annotation
|
||||
class Timestamp(BaseModel):
|
||||
"""Timestamp in MM:SS or SS format"""
|
||||
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
|
||||
end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
||||
|
||||
|
||||
class Subtask(BaseModel):
|
||||
"""Individual subtask/stage - must use EXACT names from provided list"""
|
||||
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
|
||||
timestamps: Timestamp
|
||||
|
||||
|
||||
class SubtaskAnnotation(BaseModel):
|
||||
"""Complete annotation for a robot manipulation episode"""
|
||||
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import compute_temporal_proportions
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
|
||||
def create_sarm_prompt(subtask_list: list[str]) -> str:
|
||||
"""
|
||||
@@ -769,59 +751,6 @@ def worker_process_episodes(
|
||||
return annotations
|
||||
|
||||
|
||||
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
|
||||
"""
|
||||
Compute average temporal proportion for each subtask across all episodes.
|
||||
This is the key insight from SARM - use semantic subtasks instead of frame indices.
|
||||
"""
|
||||
# Collect all proportions per subtask
|
||||
subtask_proportions = {}
|
||||
|
||||
for annotation in annotations.values():
|
||||
# Calculate total episode duration
|
||||
total_duration = 0
|
||||
durations = {}
|
||||
|
||||
for subtask in annotation.subtasks:
|
||||
# Parse timestamps
|
||||
start_parts = subtask.timestamps.start.split(":")
|
||||
end_parts = subtask.timestamps.end.split(":")
|
||||
|
||||
if len(start_parts) == 2:
|
||||
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1])
|
||||
else:
|
||||
start_seconds = int(start_parts[0])
|
||||
|
||||
if len(end_parts) == 2:
|
||||
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1])
|
||||
else:
|
||||
end_seconds = int(end_parts[0])
|
||||
|
||||
duration = end_seconds - start_seconds
|
||||
durations[subtask.name] = duration
|
||||
total_duration += duration
|
||||
|
||||
# Calculate proportions for this episode
|
||||
if total_duration > 0:
|
||||
for name, duration in durations.items():
|
||||
if name not in subtask_proportions:
|
||||
subtask_proportions[name] = []
|
||||
subtask_proportions[name].append(duration / total_duration)
|
||||
|
||||
# Average across episodes
|
||||
avg_proportions = {
|
||||
name: sum(props) / len(props)
|
||||
for name, props in subtask_proportions.items()
|
||||
}
|
||||
|
||||
# Normalize to sum to 1.0
|
||||
total = sum(avg_proportions.values())
|
||||
if total > 0:
|
||||
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
|
||||
|
||||
return avg_proportions
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
|
||||
@@ -1185,4 +1114,3 @@ Performance Tips:
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user