simplify and cleanup code and move compute_temporal_proportions to utils

This commit is contained in:
Pepijn
2025-11-27 19:38:32 +01:00
parent 73dd4f10f7
commit adc476d8af
5 changed files with 138 additions and 254 deletions
@@ -72,13 +72,11 @@ import argparse
import json
import time
from pathlib import Path
from typing import Optional
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import pandas as pd
import torch
from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info
from rich.console import Console
from rich.panel import Panel
@@ -86,24 +84,8 @@ from rich.tree import Tree
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
class Subtask(BaseModel):
"""Individual subtask/stage - must use EXACT names from provided list"""
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
timestamps: Timestamp
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
from lerobot.policies.sarm.sarm_utils import compute_temporal_proportions
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
def create_sarm_prompt(subtask_list: list[str]) -> str:
"""
@@ -769,59 +751,6 @@ def worker_process_episodes(
return annotations
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
"""
Compute average temporal proportion for each subtask across all episodes.
This is the key insight from SARM - use semantic subtasks instead of frame indices.
"""
# Collect all proportions per subtask
subtask_proportions = {}
for annotation in annotations.values():
# Calculate total episode duration
total_duration = 0
durations = {}
for subtask in annotation.subtasks:
# Parse timestamps
start_parts = subtask.timestamps.start.split(":")
end_parts = subtask.timestamps.end.split(":")
if len(start_parts) == 2:
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1])
else:
start_seconds = int(start_parts[0])
if len(end_parts) == 2:
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1])
else:
end_seconds = int(end_parts[0])
duration = end_seconds - start_seconds
durations[subtask.name] = duration
total_duration += duration
# Calculate proportions for this episode
if total_duration > 0:
for name, duration in durations.items():
if name not in subtask_proportions:
subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
# Average across episodes
avg_proportions = {
name: sum(props) / len(props)
for name, props in subtask_proportions.items()
}
# Normalize to sum to 1.0
total = sum(avg_proportions.values())
if total > 0:
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
return avg_proportions
def main():
parser = argparse.ArgumentParser(
description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
@@ -1185,4 +1114,3 @@ Performance Tips:
if __name__ == "__main__":
main()
+1 -1
View File
@@ -27,7 +27,7 @@ from transformers import CLIPModel, CLIPProcessor
from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module):
+12 -51
View File
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Any
import numpy as np
import torch
@@ -23,6 +21,7 @@ from PIL import Image
import pandas as pd
from transformers import CLIPModel, CLIPProcessor
from lerobot.processor.core import TransitionKey
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import (
@@ -119,7 +118,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
frame_indices: np.ndarray,
episode_indices: np.ndarray,
num_frames: int,
is_batch: bool,
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute episode metadata for all samples.
@@ -140,10 +138,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
absolute_indices_list.append(abs_indices)
remaining_lengths.append(ep_end - abs_indices[0].item())
if is_batch:
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
else:
return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0]
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
def _compute_stage_and_progress_for_frame(
self,
@@ -189,7 +184,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
cumulative_progress = compute_cumulative_progress_batch(
tau, stage_idx, temporal_proportions_list
)
return stage_idx, cumulative_progress
# No matching subtask found
@@ -288,15 +282,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
if self.temporal_proportions is None or episode_index is None:
return None, None
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
# Normalize inputs to numpy arrays
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine sequence length
if video_features is not None and video_features.dim() >= 2:
seq_len = video_features.shape[1] if is_batch else video_features.shape[0]
seq_len = video_features.shape[1]
else:
seq_len = 1
@@ -315,24 +307,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
all_stage_labels.append(result[0])
all_progress_targets.append(result[1])
if is_batch:
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
return all_stage_labels[0], all_progress_targets[0]
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition."""
from lerobot.processor.core import TransitionKey
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
observation = new_transition.get(TransitionKey.OBSERVATION)
if not isinstance(observation, dict):
raise ValueError("Observation must be a dictionary")
# Encode images with CLIP
image = observation.get(self.image_key)
if image is None:
raise ValueError(f"Image not found in observation for key: {self.image_key}")
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
@@ -342,10 +325,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
state_key = self.config.state_key
state_data = observation.get(state_key)
if state_data is None:
state_data = observation.get("state") or observation.get("observation.state")
if state_data is None:
raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
if isinstance(state_data, torch.Tensor):
state_tensor = state_data.float()
@@ -355,8 +334,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
# Get task description from dataset (complementary_data["task"])
task = comp_data.get('task')
@@ -378,18 +355,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None:
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine number of frames from video features
if video_features.dim() >= 2:
num_frames = video_features.shape[1] if is_batch else video_features.shape[0]
num_frames = video_features.shape[1]
else:
num_frames = 1
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
frame_indices, episode_indices, num_frames, is_batch
frame_indices, episode_indices, num_frames
)
observation['absolute_frame_indices'] = abs_indices
observation['remaining_length'] = remaining
@@ -412,24 +388,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
"""Encode a batch of images using CLIP.
Args:
images: Batched images with shape:
- (B, C, H, W) for single frames, or
- (B, T, C, H, W) for temporal sequences
images: Batched images with shape: (B, T, C, H, W)
Returns:
Encoded feature vectors with shape (B, 512) or (B, T, 512)
Encoded feature vectors with shape (B, T, 512)
"""
# Check if we have temporal dimension
has_temporal = len(images.shape) == 5
if has_temporal: # Shape: (B, T, C, H, W)
batch_size, seq_length = images.shape[0], images.shape[1]
images = images.reshape(batch_size * seq_length, *images.shape[2:])
elif len(images.shape) == 4: # Shape: (B, C, H, W)
batch_size = images.shape[0]
seq_length = 1
else:
raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}")
batch_size, seq_length = images.shape[0], images.shape[1]
images = images.reshape(batch_size * seq_length, *images.shape[2:])
# Convert to list of PIL images
num_frames = images.shape[0]
@@ -470,9 +436,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Concatenate all embeddings
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
# Reshape back if temporal
if has_temporal:
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
# Reshape back
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
return all_embeddings
@@ -495,10 +460,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Get text features from CLIP
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
# Handle single text case
if text_embedding.dim() == 1:
text_embedding = text_embedding.unsqueeze(0)
# Replicate for batch (B, 512)
text_embedding = text_embedding.expand(batch_size, -1)
+63 -54
View File
@@ -14,25 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions for SARM progress label computation.
Implements formulas from the SARM paper:
- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
"""
import numpy as np
import torch
import torch.nn.functional as F
from typing import Sequence
from typing import Sequence, Any
from pydantic import BaseModel, Field
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
def compute_priors(
subtask_durations_per_trajectory: dict[str, list[float]],
trajectory_lengths: dict[str, list[float]],
subtask_names: list[str],
) -> dict[str, float]:
class Subtask(BaseModel):
"""Individual subtask/stage - must use EXACT names from provided list"""
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
timestamps: Timestamp
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
"""
Compute dataset-level temporal proportions (priors) for each subtask.
@@ -40,61 +46,64 @@ def compute_priors(
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
where:
- M is the number of trajectories
- L_{i,k} is the length of subtask k in trajectory i
- T_i is the total length of trajectory i
- M is the number of trajectories (episodes)
- L_{i,k} is the duration of subtask k in trajectory i
- T_i is the total duration of trajectory i
This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of their absolute length.
Args:
subtask_durations_per_trajectory: Dict mapping subtask name to list of
(duration, trajectory_length) tuples for each occurrence
trajectory_lengths: Dict mapping subtask name to list of trajectory lengths
for each occurrence of that subtask
subtask_names: Ordered list of subtask names
annotations: Dict mapping episode index to SubtaskAnnotation object.
Each annotation has a .subtasks list where each subtask has:
- .name: subtask name
- .timestamps.start: start time as "MM:SS" string
- .timestamps.end: end time as "MM:SS" string
fps: Frames per second (unused, kept for API compatibility)
Returns:
Dict mapping subtask name to its temporal proportion (ᾱ_k)
Dict mapping subtask name to its temporal proportion (ᾱ_k).
Proportions are normalized to sum to 1.0.
"""
if not subtask_names:
raise ValueError("subtask_names cannot be empty")
subtask_proportions: dict[str, list[float]] = {}
# Compute proportion per occurrence: L_{i,k} / T_i
subtask_proportions = {}
for name in subtask_names:
if name in subtask_durations_per_trajectory and name in trajectory_lengths:
durations = subtask_durations_per_trajectory[name]
traj_lengths = trajectory_lengths[name]
for annotation in annotations.values():
total_duration = 0
durations: dict[str, int] = {}
if len(durations) != len(traj_lengths):
raise ValueError(
f"Mismatch in lengths for subtask '{name}': "
f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
)
for subtask in annotation.subtasks:
start_parts = subtask.timestamps.start.split(":")
end_parts = subtask.timestamps.end.split(":")
# Compute L_{i,k} / T_i for each occurrence
proportions = []
for duration, traj_len in zip(durations, traj_lengths):
if traj_len > 0:
proportions.append(duration / traj_len)
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
# Average across all occurrences: (1/M) × Σ_i (L_{i,k} / T_i)
subtask_proportions[name] = np.mean(proportions) if proportions else 0.0
else:
subtask_proportions[name] = 0.0
duration = end_seconds - start_seconds
durations[subtask.name] = duration
total_duration += duration
# Normalize to ensure sum = 1 (handles floating point errors and missing subtasks)
total = sum(subtask_proportions.values())
# Calculate L_{i,k} / T_i for each subtask in this trajectory
if total_duration > 0:
for name, duration in durations.items():
if name not in subtask_proportions:
subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
if not subtask_proportions:
return {}
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
avg_proportions = {
name: sum(props) / len(props)
for name, props in subtask_proportions.items()
}
# Normalize to ensure sum = 1
total = sum(avg_proportions.values())
if total > 0:
subtask_proportions = {
name: prop / total for name, prop in subtask_proportions.items()
}
else:
raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
"Check that your dataset has valid subtask annotations with start/end times.")
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
return subtask_proportions
return avg_proportions
def compute_tau(
+56 -70
View File
@@ -18,7 +18,7 @@
Tests for SARM utility functions.
Tests the implementation of SARM paper formulas:
- Formula (1): compute_priors - dataset-level temporal proportions
- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
"""
@@ -26,15 +26,31 @@ import pytest
import numpy as np
import torch
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
from lerobot.policies.sarm.sarm_utils import (
compute_priors,
compute_temporal_proportions,
compute_tau,
compute_cumulative_progress_batch,
)
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
return SubtaskAnnotation(
subtasks=[
Subtask(
name=name,
timestamps=Timestamp(
start=f"{start // 60:02d}:{start % 60:02d}",
end=f"{end // 60:02d}:{end % 60:02d}"
)
)
for name, start, end in subtasks
]
)
class TestComputePriors:
"""Tests for compute_priors (SARM Paper Formula 1).
class TestComputeTemporalProportions:
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
@@ -45,42 +61,34 @@ class TestComputePriors:
def test_basic_two_trajectories_equal_proportions(self):
"""Test with two trajectories that have equal proportions."""
# Both trajectories: subtask1 = 50%, subtask2 = 50%
subtask_durations = {
'subtask1': [50, 100], # durations
'subtask2': [50, 100],
# Traj 1: T=100s, subtask1=50s, subtask2=50s
# Traj 2: T=200s, subtask1=100s, subtask2=100s
annotations = {
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
result = compute_temporal_proportions(annotations)
# Both should be 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_paper_example_different_from_avg_durations(self):
"""Test that compute_priors differs from naive average duration approach.
"""Test that compute_temporal_proportions differs from naive average duration approach.
This is the key test showing the difference between:
- Paper formula: average of (L_i,k / T_i)
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
"""
# Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
# Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
subtask_durations = {
'subtask1': [80, 40],
'subtask2': [20, 160],
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
annotations = {
0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
result = compute_temporal_proportions(annotations)
# Paper formula:
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
@@ -88,22 +96,14 @@ class TestComputePriors:
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_single_trajectory(self):
"""Test with a single trajectory."""
subtask_durations = {
'reach': [30],
'grasp': [20],
'lift': [50],
# T=100s, reach=30s, grasp=20s, lift=50s
annotations = {
0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
}
trajectory_lengths = {
'reach': [100],
'grasp': [100],
'lift': [100],
}
subtask_names = ['grasp', 'lift', 'reach'] # sorted order
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
result = compute_temporal_proportions(annotations)
assert abs(result['reach'] - 0.3) < 1e-6
assert abs(result['grasp'] - 0.2) < 1e-6
@@ -111,49 +111,35 @@ class TestComputePriors:
def test_sum_to_one(self):
"""Test that proportions always sum to 1."""
subtask_durations = {
'a': [10, 20, 30],
'b': [40, 50, 60],
'c': [50, 30, 10],
# Three episodes with varying proportions
annotations = {
0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
}
trajectory_lengths = {
'a': [100, 100, 100],
'b': [100, 100, 100],
'c': [100, 100, 100],
}
subtask_names = ['a', 'b', 'c']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
result = compute_temporal_proportions(annotations)
total = sum(result.values())
assert abs(total - 1.0) < 1e-6
def test_empty_subtask_names_raises(self):
"""Test that empty subtask_names raises an error."""
with pytest.raises(ValueError, match="subtask_names cannot be empty"):
compute_priors({}, {}, [])
def test_empty_annotations_returns_empty(self):
"""Test that empty annotations returns empty dict."""
result = compute_temporal_proportions({})
assert result == {}
def test_missing_subtask_gets_zero_before_normalization(self):
"""Test handling of subtasks that appear in some but not all trajectories."""
# subtask1 appears in both, subtask2 only in first
subtask_durations = {
'subtask1': [50, 100],
'subtask2': [50], # only in first trajectory
def test_uniform_proportions(self):
"""Test with uniform proportions across subtasks."""
# Each subtask takes 25% of each episode
annotations = {
0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
result = compute_temporal_proportions(annotations)
# subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
# subtask2: 50/100 = 0.5 (only one occurrence)
# After normalization: both should be 0.5
assert result['subtask1'] > 0
assert result['subtask2'] > 0
assert abs(sum(result.values()) - 1.0) < 1e-6
for name in ['a', 'b', 'c', 'd']:
assert abs(result[name] - 0.25) < 1e-6
class TestComputeTau: