Simplify and clean up code, and move compute_temporal_proportions to utils

This commit is contained in:
Pepijn
2025-11-27 19:38:32 +01:00
parent 73dd4f10f7
commit adc476d8af
5 changed files with 138 additions and 254 deletions
@@ -72,13 +72,11 @@ import argparse
import json import json
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional
from concurrent.futures import ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp import multiprocessing as mp
import pandas as pd import pandas as pd
import torch import torch
from pydantic import BaseModel, Field
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
@@ -86,24 +84,8 @@ from rich.tree import Tree
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.sarm_utils import compute_temporal_proportions
# Pydantic Models for SARM-style Annotation from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
# Start/end pair for one subtask segment. Both values are kept as plain strings;
# this model performs no format validation of its own.
# NOTE(review): pydantic presumably surfaces the class docstring and the Field
# descriptions in the generated JSON schema, so treat them as runtime text — confirm
# before rewording.
class Timestamp(BaseModel):
    """Timestamp in MM:SS or SS format"""
    # Segment start, e.g. "01:05" or "65".
    start: str = Field(description="Start timestamp (MM:SS or just seconds)")
    # Segment end, same format as `start`.
    end: str = Field(description="End timestamp (MM:SS or just seconds)")
# One named stage of an episode together with the time span it covers.
# The "must match exactly" requirement is stated only in the description text;
# nothing in this model enforces it.
class Subtask(BaseModel):
    """Individual subtask/stage - must use EXACT names from provided list"""
    # Stage label.
    name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
    # Start/end strings for this stage.
    timestamps: Timestamp
# Top-level annotation object for a single episode: an ordered list of Subtask entries.
# Temporal ordering is requested via the description only; it is not checked here.
class SubtaskAnnotation(BaseModel):
    """Complete annotation for a robot manipulation episode"""
    # Stages of the episode, expected in the order they occur.
    subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
def create_sarm_prompt(subtask_list: list[str]) -> str: def create_sarm_prompt(subtask_list: list[str]) -> str:
""" """
@@ -769,59 +751,6 @@ def worker_process_episodes(
return annotations return annotations
def _timestamp_to_seconds(stamp: str) -> int:
    """Convert a colon-separated timestamp ("SS", "MM:SS" or "HH:MM:SS") to whole seconds."""
    seconds = 0
    for part in stamp.split(":"):
        # Every field is worth 60x the field to its right.
        seconds = seconds * 60 + int(part)
    return seconds


def compute_temporal_proportions(annotations: dict[int, "SubtaskAnnotation"], fps: int = 30) -> dict[str, float]:
    """
    Compute the average temporal proportion of each subtask across all episodes.

    For every episode the duration of each subtask is divided by the summed
    duration of all subtasks in that episode; these per-episode proportions are
    then averaged per subtask name and finally rescaled so they sum to 1.0.

    Args:
        annotations: Mapping from episode index to an annotation object exposing
            ``.subtasks``; each subtask exposes ``.name`` and
            ``.timestamps.start`` / ``.timestamps.end`` as "SS", "MM:SS" or
            "HH:MM:SS" strings with integer fields.
        fps: Unused; kept so existing callers that pass it keep working.

    Returns:
        Mapping from subtask name to its normalized average proportion. Empty
        when there are no annotations or no episode has a positive total duration.

    Raises:
        ValueError: If a timestamp field is not an integer.
    """
    proportions_by_subtask: dict[str, list[float]] = {}

    for annotation in annotations.values():
        durations: dict[str, int] = {}
        for subtask in annotation.subtasks:
            start_seconds = _timestamp_to_seconds(subtask.timestamps.start)
            end_seconds = _timestamp_to_seconds(subtask.timestamps.end)
            # A subtask name may occur more than once in an episode; accumulate
            # instead of overwriting so the per-episode proportions still sum to 1.
            durations[subtask.name] = durations.get(subtask.name, 0) + (end_seconds - start_seconds)

        total_duration = sum(durations.values())
        if total_duration <= 0:
            # Episode without usable timing information contributes nothing.
            continue

        for name, duration in durations.items():
            proportions_by_subtask.setdefault(name, []).append(duration / total_duration)

    # Average across the episodes in which each subtask appears.
    avg_proportions = {
        name: sum(props) / len(props)
        for name, props in proportions_by_subtask.items()
    }

    # Subtasks missing from some episodes make the averages drift from 1.0; renormalize.
    total = sum(avg_proportions.values())
    if total > 0:
        avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}

    return avg_proportions
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="SARM-style subtask annotation using local GPU (Qwen3-VL)", description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
@@ -1185,4 +1114,3 @@ Performance Tips:
if __name__ == "__main__": if __name__ == "__main__":
main() main()
+1 -1
View File
@@ -27,7 +27,7 @@ from transformers import CLIPModel, CLIPProcessor
from torch import Tensor from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
class SARMTransformer(nn.Module): class SARMTransformer(nn.Module):
+7 -46
View File
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging
from typing import Any from typing import Any
import numpy as np import numpy as np
import torch import torch
@@ -23,6 +21,7 @@ from PIL import Image
import pandas as pd import pandas as pd
from transformers import CLIPModel, CLIPProcessor from transformers import CLIPModel, CLIPProcessor
from lerobot.processor.core import TransitionKey
from lerobot.policies.sarm.configuration_sarm import SARMConfig from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
from lerobot.processor import ( from lerobot.processor import (
@@ -119,7 +118,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
frame_indices: np.ndarray, frame_indices: np.ndarray,
episode_indices: np.ndarray, episode_indices: np.ndarray,
num_frames: int, num_frames: int,
is_batch: bool,
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute episode metadata for all samples. """Compute episode metadata for all samples.
@@ -140,10 +138,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
absolute_indices_list.append(abs_indices) absolute_indices_list.append(abs_indices)
remaining_lengths.append(ep_end - abs_indices[0].item()) remaining_lengths.append(ep_end - abs_indices[0].item())
if is_batch:
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths) return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
else:
return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0]
def _compute_stage_and_progress_for_frame( def _compute_stage_and_progress_for_frame(
self, self,
@@ -189,7 +184,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
cumulative_progress = compute_cumulative_progress_batch( cumulative_progress = compute_cumulative_progress_batch(
tau, stage_idx, temporal_proportions_list tau, stage_idx, temporal_proportions_list
) )
return stage_idx, cumulative_progress return stage_idx, cumulative_progress
# No matching subtask found # No matching subtask found
@@ -288,15 +282,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
if self.temporal_proportions is None or episode_index is None: if self.temporal_proportions is None or episode_index is None:
return None, None return None, None
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
# Normalize inputs to numpy arrays # Normalize inputs to numpy arrays
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index))) frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index) episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine sequence length # Determine sequence length
if video_features is not None and video_features.dim() >= 2: if video_features is not None and video_features.dim() >= 2:
seq_len = video_features.shape[1] if is_batch else video_features.shape[0] seq_len = video_features.shape[1]
else: else:
seq_len = 1 seq_len = 1
@@ -315,24 +307,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
all_stage_labels.append(result[0]) all_stage_labels.append(result[0])
all_progress_targets.append(result[1]) all_progress_targets.append(result[1])
if is_batch:
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0) return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
return all_stage_labels[0], all_progress_targets[0]
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images, text, and normalize states in the transition.""" """Encode images, text, and normalize states in the transition."""
from lerobot.processor.core import TransitionKey
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition) new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
observation = new_transition.get(TransitionKey.OBSERVATION) observation = new_transition.get(TransitionKey.OBSERVATION)
if not isinstance(observation, dict):
raise ValueError("Observation must be a dictionary")
# Encode images with CLIP
image = observation.get(self.image_key) image = observation.get(self.image_key)
if image is None:
raise ValueError(f"Image not found in observation for key: {self.image_key}")
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
image = image.cpu().numpy() image = image.cpu().numpy()
@@ -342,10 +325,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep) # Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
state_key = self.config.state_key state_key = self.config.state_key
state_data = observation.get(state_key) state_data = observation.get(state_key)
if state_data is None:
state_data = observation.get("state") or observation.get("observation.state")
if state_data is None:
raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
if isinstance(state_data, torch.Tensor): if isinstance(state_data, torch.Tensor):
state_tensor = state_data.float() state_tensor = state_data.float()
@@ -355,8 +334,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim) observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if not isinstance(comp_data, dict):
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
# Get task description from dataset (complementary_data["task"]) # Get task description from dataset (complementary_data["task"])
task = comp_data.get('task') task = comp_data.get('task')
@@ -378,18 +355,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Compute episode metadata if dataset_meta is available # Compute episode metadata if dataset_meta is available
if self.dataset_meta is not None: if self.dataset_meta is not None:
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index))) frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
episode_indices = self._get_episode_indices(frame_indices, episode_index) episode_indices = self._get_episode_indices(frame_indices, episode_index)
# Determine number of frames from video features # Determine number of frames from video features
if video_features.dim() >= 2: if video_features.dim() >= 2:
num_frames = video_features.shape[1] if is_batch else video_features.shape[0] num_frames = video_features.shape[1]
else: else:
num_frames = 1 num_frames = 1
abs_indices, remaining, ep_lengths = self._compute_episode_metadata( abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
frame_indices, episode_indices, num_frames, is_batch frame_indices, episode_indices, num_frames
) )
observation['absolute_frame_indices'] = abs_indices observation['absolute_frame_indices'] = abs_indices
observation['remaining_length'] = remaining observation['remaining_length'] = remaining
@@ -412,24 +388,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
"""Encode a batch of images using CLIP. """Encode a batch of images using CLIP.
Args: Args:
images: Batched images with shape: images: Batched images with shape: (B, T, C, H, W)
- (B, C, H, W) for single frames, or
- (B, T, C, H, W) for temporal sequences
Returns: Returns:
Encoded feature vectors with shape (B, 512) or (B, T, 512) Encoded feature vectors with shape (B, T, 512)
""" """
# Check if we have temporal dimension
has_temporal = len(images.shape) == 5
if has_temporal: # Shape: (B, T, C, H, W)
batch_size, seq_length = images.shape[0], images.shape[1] batch_size, seq_length = images.shape[0], images.shape[1]
images = images.reshape(batch_size * seq_length, *images.shape[2:]) images = images.reshape(batch_size * seq_length, *images.shape[2:])
elif len(images.shape) == 4: # Shape: (B, C, H, W)
batch_size = images.shape[0]
seq_length = 1
else:
raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}")
# Convert to list of PIL images # Convert to list of PIL images
num_frames = images.shape[0] num_frames = images.shape[0]
@@ -470,8 +436,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Concatenate all embeddings # Concatenate all embeddings
all_embeddings = torch.cat(all_embeddings) # (B*T, 512) all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
# Reshape back if temporal # Reshape back
if has_temporal:
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512) all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
return all_embeddings return all_embeddings
@@ -495,10 +460,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
# Get text features from CLIP # Get text features from CLIP
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu() text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
# Handle single text case
if text_embedding.dim() == 1:
text_embedding = text_embedding.unsqueeze(0)
# Replicate for batch (B, 512) # Replicate for batch (B, 512)
text_embedding = text_embedding.expand(batch_size, -1) text_embedding = text_embedding.expand(batch_size, -1)
+63 -54
View File
@@ -14,25 +14,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Utility functions for SARM progress label computation.
Implements formulas from the SARM paper:
- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
"""
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Sequence from typing import Sequence, Any
from pydantic import BaseModel, Field
# Pydantic Models for SARM-style Annotation
class Timestamp(BaseModel):
"""Timestamp in MM:SS or SS format"""
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
end: str = Field(description="End timestamp (MM:SS or just seconds)")
def compute_priors( class Subtask(BaseModel):
subtask_durations_per_trajectory: dict[str, list[float]], """Individual subtask/stage - must use EXACT names from provided list"""
trajectory_lengths: dict[str, list[float]], name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
subtask_names: list[str], timestamps: Timestamp
) -> dict[str, float]:
class SubtaskAnnotation(BaseModel):
"""Complete annotation for a robot manipulation episode"""
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
""" """
Compute dataset-level temporal proportions (priors) for each subtask. Compute dataset-level temporal proportions (priors) for each subtask.
@@ -40,61 +46,64 @@ def compute_priors(
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
where: where:
- M is the number of trajectories - M is the number of trajectories (episodes)
- L_{i,k} is the length of subtask k in trajectory i - L_{i,k} is the duration of subtask k in trajectory i
- T_i is the total length of trajectory i - T_i is the total duration of trajectory i
This averages the PROPORTION of each subtask within each trajectory, This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of their absolute length. giving equal weight to all trajectories regardless of their absolute length.
Args: Args:
subtask_durations_per_trajectory: Dict mapping subtask name to list of annotations: Dict mapping episode index to SubtaskAnnotation object.
(duration, trajectory_length) tuples for each occurrence Each annotation has a .subtasks list where each subtask has:
trajectory_lengths: Dict mapping subtask name to list of trajectory lengths - .name: subtask name
for each occurrence of that subtask - .timestamps.start: start time as "MM:SS" string
subtask_names: Ordered list of subtask names - .timestamps.end: end time as "MM:SS" string
fps: Frames per second (unused, kept for API compatibility)
Returns: Returns:
Dict mapping subtask name to its temporal proportion (ᾱ_k) Dict mapping subtask name to its temporal proportion (ᾱ_k).
Proportions are normalized to sum to 1.0.
""" """
if not subtask_names: subtask_proportions: dict[str, list[float]] = {}
raise ValueError("subtask_names cannot be empty")
# Compute proportion per occurrence: L_{i,k} / T_i for annotation in annotations.values():
subtask_proportions = {} total_duration = 0
for name in subtask_names: durations: dict[str, int] = {}
if name in subtask_durations_per_trajectory and name in trajectory_lengths:
durations = subtask_durations_per_trajectory[name]
traj_lengths = trajectory_lengths[name]
if len(durations) != len(traj_lengths): for subtask in annotation.subtasks:
raise ValueError( start_parts = subtask.timestamps.start.split(":")
f"Mismatch in lengths for subtask '{name}': " end_parts = subtask.timestamps.end.split(":")
f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
)
# Compute L_{i,k} / T_i for each occurrence start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
proportions = [] end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
for duration, traj_len in zip(durations, traj_lengths):
if traj_len > 0:
proportions.append(duration / traj_len)
# Average across all occurrences: (1/M) × Σ_i (L_{i,k} / T_i) duration = end_seconds - start_seconds
subtask_proportions[name] = np.mean(proportions) if proportions else 0.0 durations[subtask.name] = duration
else: total_duration += duration
subtask_proportions[name] = 0.0
# Normalize to ensure sum = 1 (handles floating point errors and missing subtasks) # Calculate L_{i,k} / T_i for each subtask in this trajectory
total = sum(subtask_proportions.values()) if total_duration > 0:
if total > 0: for name, duration in durations.items():
subtask_proportions = { if name not in subtask_proportions:
name: prop / total for name, prop in subtask_proportions.items() subtask_proportions[name] = []
subtask_proportions[name].append(duration / total_duration)
if not subtask_proportions:
return {}
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
avg_proportions = {
name: sum(props) / len(props)
for name, props in subtask_proportions.items()
} }
else:
raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
"Check that your dataset has valid subtask annotations with start/end times.")
return subtask_proportions # Normalize to ensure sum = 1
total = sum(avg_proportions.values())
if total > 0:
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
return avg_proportions
def compute_tau( def compute_tau(
+56 -70
View File
@@ -18,7 +18,7 @@
Tests for SARM utility functions. Tests for SARM utility functions.
Tests the implementation of SARM paper formulas: Tests the implementation of SARM paper formulas:
- Formula (1): compute_priors - dataset-level temporal proportions - Formula (1): compute_temporal_proportions - dataset-level temporal proportions
- Formula (2): compute_tau, compute_cumulative_progress - progress labels - Formula (2): compute_tau, compute_cumulative_progress - progress labels
""" """
@@ -26,15 +26,31 @@ import pytest
import numpy as np import numpy as np
import torch import torch
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
from lerobot.policies.sarm.sarm_utils import ( from lerobot.policies.sarm.sarm_utils import (
compute_priors, compute_temporal_proportions,
compute_tau, compute_tau,
compute_cumulative_progress_batch, compute_cumulative_progress_batch,
) )
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
    """Build a SubtaskAnnotation from (name, start_sec, end_sec) triples.

    Each second count is rendered as a zero-padded "MM:SS" string before being
    wrapped in Timestamp/Subtask objects; input order is preserved.
    """
    def as_mmss(total_seconds: int) -> str:
        minutes, seconds = divmod(total_seconds, 60)
        return f"{minutes:02d}:{seconds:02d}"

    entries = []
    for name, start, end in subtasks:
        span = Timestamp(start=as_mmss(start), end=as_mmss(end))
        entries.append(Subtask(name=name, timestamps=span))
    return SubtaskAnnotation(subtasks=entries)
class TestComputePriors:
"""Tests for compute_priors (SARM Paper Formula 1). class TestComputeTemporalProportions:
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i) Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
@@ -45,42 +61,34 @@ class TestComputePriors:
def test_basic_two_trajectories_equal_proportions(self): def test_basic_two_trajectories_equal_proportions(self):
"""Test with two trajectories that have equal proportions.""" """Test with two trajectories that have equal proportions."""
# Both trajectories: subtask1 = 50%, subtask2 = 50% # Both trajectories: subtask1 = 50%, subtask2 = 50%
subtask_durations = { # Traj 1: T=100s, subtask1=50s, subtask2=50s
'subtask1': [50, 100], # durations # Traj 2: T=200s, subtask1=100s, subtask2=100s
'subtask2': [50, 100], annotations = {
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
} }
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) result = compute_temporal_proportions(annotations)
# Both should be 0.5 # Both should be 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6 assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6 assert abs(result['subtask2'] - 0.5) < 1e-6
def test_paper_example_different_from_avg_durations(self): def test_paper_example_different_from_avg_durations(self):
"""Test that compute_priors differs from naive average duration approach. """Test that compute_temporal_proportions differs from naive average duration approach.
This is the key test showing the difference between: This is the key test showing the difference between:
- Paper formula: average of (L_i,k / T_i) - Paper formula: average of (L_i,k / T_i)
- Naive approach: mean(L_i,k) / sum(mean(L_i,j)) - Naive approach: mean(L_i,k) / sum(mean(L_i,j))
""" """
# Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2) # Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
# Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8) # Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
subtask_durations = { annotations = {
'subtask1': [80, 40], 0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
'subtask2': [20, 160], 1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
} }
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) result = compute_temporal_proportions(annotations)
# Paper formula: # Paper formula:
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5 # ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
@@ -88,22 +96,14 @@ class TestComputePriors:
assert abs(result['subtask1'] - 0.5) < 1e-6 assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6 assert abs(result['subtask2'] - 0.5) < 1e-6
def test_single_trajectory(self): def test_single_trajectory(self):
"""Test with a single trajectory.""" """Test with a single trajectory."""
subtask_durations = { # T=100s, reach=30s, grasp=20s, lift=50s
'reach': [30], annotations = {
'grasp': [20], 0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
'lift': [50],
} }
trajectory_lengths = {
'reach': [100],
'grasp': [100],
'lift': [100],
}
subtask_names = ['grasp', 'lift', 'reach'] # sorted order
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) result = compute_temporal_proportions(annotations)
assert abs(result['reach'] - 0.3) < 1e-6 assert abs(result['reach'] - 0.3) < 1e-6
assert abs(result['grasp'] - 0.2) < 1e-6 assert abs(result['grasp'] - 0.2) < 1e-6
@@ -111,49 +111,35 @@ class TestComputePriors:
def test_sum_to_one(self): def test_sum_to_one(self):
"""Test that proportions always sum to 1.""" """Test that proportions always sum to 1."""
subtask_durations = { # Three episodes with varying proportions
'a': [10, 20, 30], annotations = {
'b': [40, 50, 60], 0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
'c': [50, 30, 10], 1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
} }
trajectory_lengths = {
'a': [100, 100, 100],
'b': [100, 100, 100],
'c': [100, 100, 100],
}
subtask_names = ['a', 'b', 'c']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) result = compute_temporal_proportions(annotations)
total = sum(result.values()) total = sum(result.values())
assert abs(total - 1.0) < 1e-6 assert abs(total - 1.0) < 1e-6
def test_empty_subtask_names_raises(self): def test_empty_annotations_returns_empty(self):
"""Test that empty subtask_names raises an error.""" """Test that empty annotations returns empty dict."""
with pytest.raises(ValueError, match="subtask_names cannot be empty"): result = compute_temporal_proportions({})
compute_priors({}, {}, []) assert result == {}
def test_missing_subtask_gets_zero_before_normalization(self): def test_uniform_proportions(self):
"""Test handling of subtasks that appear in some but not all trajectories.""" """Test with uniform proportions across subtasks."""
# subtask1 appears in both, subtask2 only in first # Each subtask takes 25% of each episode
subtask_durations = { annotations = {
'subtask1': [50, 100], 0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
'subtask2': [50], # only in first trajectory 1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
} }
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names) result = compute_temporal_proportions(annotations)
# subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5 for name in ['a', 'b', 'c', 'd']:
# subtask2: 50/100 = 0.5 (only one occurrence) assert abs(result[name] - 0.25) < 1e-6
# After normalization: both should be 0.5
assert result['subtask1'] > 0
assert result['subtask2'] > 0
assert abs(sum(result.values()) - 1.0) < 1e-6
class TestComputeTau: class TestComputeTau: