diff --git a/examples/dataset_annotation/subtask_annotation.py b/examples/dataset_annotation/subtask_annotation.py
index d84f85a6a..6c5148923 100644
--- a/examples/dataset_annotation/subtask_annotation.py
+++ b/examples/dataset_annotation/subtask_annotation.py
@@ -72,13 +72,11 @@
 import argparse
 import json
 import time
 from pathlib import Path
-from typing import Optional
 from concurrent.futures import ProcessPoolExecutor, as_completed
 import multiprocessing as mp
 import pandas as pd
 import torch
-from pydantic import BaseModel, Field
 from qwen_vl_utils import process_vision_info
 from rich.console import Console
 from rich.panel import Panel
@@ -86,24 +84,8 @@
 from rich.tree import Tree
 from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
 
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
-
-# Pydantic Models for SARM-style Annotation
-class Timestamp(BaseModel):
-    """Timestamp in MM:SS or SS format"""
-    start: str = Field(description="Start timestamp (MM:SS or just seconds)")
-    end: str = Field(description="End timestamp (MM:SS or just seconds)")
-
-
-class Subtask(BaseModel):
-    """Individual subtask/stage - must use EXACT names from provided list"""
-    name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
-    timestamps: Timestamp
-
-
-class SubtaskAnnotation(BaseModel):
-    """Complete annotation for a robot manipulation episode"""
-    subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
-
+from lerobot.policies.sarm.sarm_utils import compute_temporal_proportions
+from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
 def create_sarm_prompt(subtask_list: list[str]) -> str:
     """
@@ -769,59 +751,6 @@ def worker_process_episodes(
     return annotations
 
 
-def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
-    """
-    Compute average temporal proportion for each subtask across all episodes.
-    This is the key insight from SARM - use semantic subtasks instead of frame indices.
-    """
-    # Collect all proportions per subtask
-    subtask_proportions = {}
-
-    for annotation in annotations.values():
-        # Calculate total episode duration
-        total_duration = 0
-        durations = {}
-
-        for subtask in annotation.subtasks:
-            # Parse timestamps
-            start_parts = subtask.timestamps.start.split(":")
-            end_parts = subtask.timestamps.end.split(":")
-
-            if len(start_parts) == 2:
-                start_seconds = int(start_parts[0]) * 60 + int(start_parts[1])
-            else:
-                start_seconds = int(start_parts[0])
-
-            if len(end_parts) == 2:
-                end_seconds = int(end_parts[0]) * 60 + int(end_parts[1])
-            else:
-                end_seconds = int(end_parts[0])
-
-            duration = end_seconds - start_seconds
-            durations[subtask.name] = duration
-            total_duration += duration
-
-        # Calculate proportions for this episode
-        if total_duration > 0:
-            for name, duration in durations.items():
-                if name not in subtask_proportions:
-                    subtask_proportions[name] = []
-                subtask_proportions[name].append(duration / total_duration)
-
-    # Average across episodes
-    avg_proportions = {
-        name: sum(props) / len(props)
-        for name, props in subtask_proportions.items()
-    }
-
-    # Normalize to sum to 1.0
-    total = sum(avg_proportions.values())
-    if total > 0:
-        avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
-
-    return avg_proportions
-
-
 def main():
     parser = argparse.ArgumentParser(
         description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
@@ -1185,4 +1114,3 @@ Performance Tips:
 
 if __name__ == "__main__":
     main()
-
diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py
index fba9d2327..a6e7140b0 100644
--- a/src/lerobot/policies/sarm/modeling_sarm.py
+++ b/src/lerobot/policies/sarm/modeling_sarm.py
@@ -27,7 +27,7 @@
 from transformers import CLIPModel, CLIPProcessor
 from torch import Tensor
 
 from lerobot.policies.sarm.configuration_sarm import SARMConfig
-from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
+from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
 from lerobot.policies.pretrained import PreTrainedPolicy
 
 class SARMTransformer(nn.Module):
diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py
index 58d20ae8a..4badb1b86 100644
--- a/src/lerobot/policies/sarm/processor_sarm.py
+++ b/src/lerobot/policies/sarm/processor_sarm.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-import logging
 from typing import Any
 import numpy as np
 import torch
@@ -23,6 +21,7 @@
 from PIL import Image
 import pandas as pd
 from transformers import CLIPModel, CLIPProcessor
+from lerobot.processor.core import TransitionKey
 from lerobot.policies.sarm.configuration_sarm import SARMConfig
 from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
 from lerobot.processor import (
@@ -119,7 +118,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
         frame_indices: np.ndarray,
         episode_indices: np.ndarray,
         num_frames: int,
-        is_batch: bool,
     ) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute episode metadata for all samples.
 
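The hunks on either side of this point drop the is_batch special-casing: inputs are promoted to one-element batches up front, and per-sample results are always stacked. A minimal standalone sketch of that batch-first convention (illustration only, not part of the patch; to_batch is a hypothetical stand-in for the np.atleast_1d normalization the processor uses):

import numpy as np
import torch

def to_batch(frame_index) -> np.ndarray:
    # Scalars and sequences are both promoted to a 1-D batch,
    # so downstream code never needs an is_batch branch.
    return np.atleast_1d(np.asarray(frame_index))

print(to_batch(7).shape)          # (1,)  - a single sample becomes a 1-element batch
print(to_batch([7, 8, 9]).shape)  # (3,)

# Per-sample results are stacked even when there is one sample, so callers
# always receive (B, ...) tensors instead of (...) when B == 1.
stage_labels = [torch.zeros(4)]
print(torch.stack(stage_labels, dim=0).shape)  # torch.Size([1, 4])
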
@@ -140,10 +138,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
             absolute_indices_list.append(abs_indices)
             remaining_lengths.append(ep_end - abs_indices[0].item())
 
-        if is_batch:
-            return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
-        else:
-            return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0]
+        return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
 
     def _compute_stage_and_progress_for_frame(
         self,
@@ -188,8 +183,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
         # Compute cumulative progress using utility function (Paper Formula 2)
         cumulative_progress = compute_cumulative_progress_batch(
             tau, stage_idx, temporal_proportions_list
-        )
-
+        )
         return stage_idx, cumulative_progress
 
         # No matching subtask found
@@ -288,15 +282,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
         if self.temporal_proportions is None or episode_index is None:
             return None, None
 
-        is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
-
         # Normalize inputs to numpy arrays
         frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
         episode_indices = self._get_episode_indices(frame_indices, episode_index)
 
         # Determine sequence length
         if video_features is not None and video_features.dim() >= 2:
-            seq_len = video_features.shape[1] if is_batch else video_features.shape[0]
+            seq_len = video_features.shape[1]
         else:
             seq_len = 1
 
@@ -315,24 +307,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
             all_stage_labels.append(result[0])
             all_progress_targets.append(result[1])
 
-        if is_batch:
-            return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
-        return all_stage_labels[0], all_progress_targets[0]
+        return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
 
     def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Encode images, text, and normalize states in the transition."""
-        from lerobot.processor.core import TransitionKey
-
+
         new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
-
         observation = new_transition.get(TransitionKey.OBSERVATION)
-        if not isinstance(observation, dict):
-            raise ValueError("Observation must be a dictionary")
-
         # Encode images with CLIP
         image = observation.get(self.image_key)
-        if image is None:
-            raise ValueError(f"Image not found in observation for key: {self.image_key}")
 
         if isinstance(image, torch.Tensor):
             image = image.cpu().numpy()
@@ -342,10 +325,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
         # Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
         state_key = self.config.state_key
         state_data = observation.get(state_key)
-        if state_data is None:
-            state_data = observation.get("state") or observation.get("observation.state")
-        if state_data is None:
-            raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
 
         if isinstance(state_data, torch.Tensor):
             state_tensor = state_data.float()
@@ -355,8 +334,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
         observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
 
         comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
-        if not isinstance(comp_data, dict):
-            raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
 
         # Get task description from dataset (complementary_data["task"])
         task = comp_data.get('task')
@@ -378,18 +355,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
 
         # Compute episode metadata if dataset_meta is available
         if self.dataset_meta is not None:
-            is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
             frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
             episode_indices = self._get_episode_indices(frame_indices, episode_index)
 
             # Determine number of frames from video features
             if video_features.dim() >= 2:
-                num_frames = video_features.shape[1] if is_batch else video_features.shape[0]
+                num_frames = video_features.shape[1]
             else:
                 num_frames = 1
 
             abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
-                frame_indices, episode_indices, num_frames, is_batch
+                frame_indices, episode_indices, num_frames
             )
             observation['absolute_frame_indices'] = abs_indices
             observation['remaining_length'] = remaining
@@ -412,24 +388,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
         """Encode a batch of images using CLIP.
 
         Args:
-            images: Batched images with shape:
-                - (B, C, H, W) for single frames, or
-                - (B, T, C, H, W) for temporal sequences
+            images: Batched images with shape: (B, T, C, H, W)
 
         Returns:
-            Encoded feature vectors with shape (B, 512) or (B, T, 512)
+            Encoded feature vectors with shape (B, T, 512)
         """
-        # Check if we have temporal dimension
-        has_temporal = len(images.shape) == 5
-
-        if has_temporal:  # Shape: (B, T, C, H, W)
-            batch_size, seq_length = images.shape[0], images.shape[1]
-            images = images.reshape(batch_size * seq_length, *images.shape[2:])
-        elif len(images.shape) == 4:  # Shape: (B, C, H, W)
-            batch_size = images.shape[0]
-            seq_length = 1
-        else:
-            raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}")
+
+        batch_size, seq_length = images.shape[0], images.shape[1]
+        images = images.reshape(batch_size * seq_length, *images.shape[2:])
 
         # Convert to list of PIL images
         num_frames = images.shape[0]
@@ -470,9 +436,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
         # Concatenate all embeddings
         all_embeddings = torch.cat(all_embeddings)  # (B*T, 512)
 
-        # Reshape back if temporal
-        if has_temporal:
-            all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1)  # (B, T, 512)
+        # Reshape back
+        all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1)  # (B, T, 512)
 
         return all_embeddings
 
@@ -495,10 +460,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
         # Get text features from CLIP
         text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
 
-        # Handle single text case
-        if text_embedding.dim() == 1:
-            text_embedding = text_embedding.unsqueeze(0)
-
         # Replicate for batch (B, 512)
         text_embedding = text_embedding.expand(batch_size, -1)
 
diff --git a/src/lerobot/policies/sarm/sarm_utils.py b/src/lerobot/policies/sarm/sarm_utils.py
index 307480a8c..ae9c4185a 100644
--- a/src/lerobot/policies/sarm/sarm_utils.py
+++ b/src/lerobot/policies/sarm/sarm_utils.py
@@ -14,25 +14,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""
-Utility functions for SARM progress label computation.
-
-Implements formulas from the SARM paper:
-- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
-- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
-"""
-
 import numpy as np
 import torch
 import torch.nn.functional as F
-from typing import Sequence
+from typing import Sequence, Any
+from pydantic import BaseModel, Field
+
+
+# Pydantic Models for SARM-style Annotation
+class Timestamp(BaseModel):
+    """Timestamp in MM:SS or SS format"""
+    start: str = Field(description="Start timestamp (MM:SS or just seconds)")
+    end: str = Field(description="End timestamp (MM:SS or just seconds)")
 
 
-def compute_priors(
-    subtask_durations_per_trajectory: dict[str, list[float]],
-    trajectory_lengths: dict[str, list[float]],
-    subtask_names: list[str],
-) -> dict[str, float]:
+class Subtask(BaseModel):
+    """Individual subtask/stage - must use EXACT names from provided list"""
+    name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
+    timestamps: Timestamp
+
+
+class SubtaskAnnotation(BaseModel):
+    """Complete annotation for a robot manipulation episode"""
+    subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
+
+
+def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
     """
     Compute dataset-level temporal proportions (priors) for each subtask.
 
@@ -40,61 +46,64 @@ def compute_priors(
         ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
 
     where:
-    - M is the number of trajectories
-    - L_{i,k} is the length of subtask k in trajectory i
-    - T_i is the total length of trajectory i
+    - M is the number of trajectories (episodes)
+    - L_{i,k} is the duration of subtask k in trajectory i
+    - T_i is the total duration of trajectory i
 
     This averages the PROPORTION of each subtask within each trajectory,
     giving equal weight to all trajectories regardless of their absolute length.
 
     Args:
-        subtask_durations_per_trajectory: Dict mapping subtask name to list of
-            (duration, trajectory_length) tuples for each occurrence
-        trajectory_lengths: Dict mapping subtask name to list of trajectory lengths
-            for each occurrence of that subtask
-        subtask_names: Ordered list of subtask names
+        annotations: Dict mapping episode index to SubtaskAnnotation object.
+            Each annotation has a .subtasks list where each subtask has:
+            - .name: subtask name
+            - .timestamps.start: start time as an "MM:SS" (or plain seconds) string
+            - .timestamps.end: end time as an "MM:SS" (or plain seconds) string
+        fps: Frames per second (unused, kept for API compatibility)
 
     Returns:
-        Dict mapping subtask name to its temporal proportion (ᾱ_k)
+        Dict mapping subtask name to its temporal proportion (ᾱ_k).
+        Proportions are normalized to sum to 1.0.
     """
-    if not subtask_names:
-        raise ValueError("subtask_names cannot be empty")
+    subtask_proportions: dict[str, list[float]] = {}
 
-    # Compute proportion per occurrence: L_{i,k} / T_i
-    subtask_proportions = {}
-    for name in subtask_names:
-        if name in subtask_durations_per_trajectory and name in trajectory_lengths:
-            durations = subtask_durations_per_trajectory[name]
-            traj_lengths = trajectory_lengths[name]
+    for annotation in annotations.values():
+        total_duration = 0
+        durations: dict[str, int] = {}
+
+        for subtask in annotation.subtasks:
+            start_parts = subtask.timestamps.start.split(":")
+            end_parts = subtask.timestamps.end.split(":")
 
-            if len(durations) != len(traj_lengths):
-                raise ValueError(
-                    f"Mismatch in lengths for subtask '{name}': "
-                    f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
-                )
+            start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
+            end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
 
-            # Compute L_{i,k} / T_i for each occurrence
-            proportions = []
-            for duration, traj_len in zip(durations, traj_lengths):
-                if traj_len > 0:
-                    proportions.append(duration / traj_len)
-
-            # Average across all occurrences: (1/M) × Σ_i (L_{i,k} / T_i)
-            subtask_proportions[name] = np.mean(proportions) if proportions else 0.0
-        else:
-            subtask_proportions[name] = 0.0
+            duration = end_seconds - start_seconds
+            durations[subtask.name] = duration
+            total_duration += duration
+
+        # Calculate L_{i,k} / T_i for each subtask in this trajectory
+        if total_duration > 0:
+            for name, duration in durations.items():
+                if name not in subtask_proportions:
+                    subtask_proportions[name] = []
+                subtask_proportions[name].append(duration / total_duration)
 
-    # Normalize to ensure sum = 1 (handles floating point errors and missing subtasks)
-    total = sum(subtask_proportions.values())
+    if not subtask_proportions:
+        return {}
+
+    # Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
+    avg_proportions = {
+        name: sum(props) / len(props)
+        for name, props in subtask_proportions.items()
+    }
+
+    # Normalize to ensure sum = 1
+    total = sum(avg_proportions.values())
     if total > 0:
-        subtask_proportions = {
-            name: prop / total for name, prop in subtask_proportions.items()
-        }
-    else:
-        raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
-                         "Check that your dataset has valid subtask annotations with start/end times.")
+        avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
 
-    return subtask_proportions
+    return avg_proportions
 
 
 def compute_tau(
diff --git a/tests/policies/test_sarm_utils.py b/tests/policies/test_sarm_utils.py
index 2ae2b6468..cd625af0d 100644
--- a/tests/policies/test_sarm_utils.py
+++ b/tests/policies/test_sarm_utils.py
@@ -18,7 +18,7 @@
 Tests for SARM utility functions.
 
 Tests the implementation of SARM paper formulas:
-- Formula (1): compute_priors - dataset-level temporal proportions
+- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
 - Formula (2): compute_tau, compute_cumulative_progress - progress labels
 """
 
@@ -26,15 +26,31 @@
 import pytest
 import numpy as np
 import torch
 
+from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
 from lerobot.policies.sarm.sarm_utils import (
-    compute_priors,
+    compute_temporal_proportions,
     compute_tau,
     compute_cumulative_progress_batch,
 )
 
 
+def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
+    """Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
+    return SubtaskAnnotation(
+        subtasks=[
+            Subtask(
+                name=name,
+                timestamps=Timestamp(
+                    start=f"{start // 60:02d}:{start % 60:02d}",
+                    end=f"{end // 60:02d}:{end % 60:02d}"
+                )
+            )
+            for name, start, end in subtasks
+        ]
+    )
-class TestComputePriors:
-    """Tests for compute_priors (SARM Paper Formula 1).
+
+
+class TestComputeTemporalProportions:
+    """Tests for compute_temporal_proportions (SARM Paper Formula 1).
 
     Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
@@ -45,65 +61,49 @@
     def test_basic_two_trajectories_equal_proportions(self):
         """Test with two trajectories that have equal proportions."""
         # Both trajectories: subtask1 = 50%, subtask2 = 50%
-        subtask_durations = {
-            'subtask1': [50, 100],  # durations
-            'subtask2': [50, 100],
+        # Traj 1: T=100s, subtask1=50s, subtask2=50s
+        # Traj 2: T=200s, subtask1=100s, subtask2=100s
+        annotations = {
+            0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
+            1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
         }
-        trajectory_lengths = {
-            'subtask1': [100, 200],
-            'subtask2': [100, 200],
-        }
-        subtask_names = ['subtask1', 'subtask2']
 
-        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
+        result = compute_temporal_proportions(annotations)
 
         # Both should be 0.5
         assert abs(result['subtask1'] - 0.5) < 1e-6
         assert abs(result['subtask2'] - 0.5) < 1e-6
 
     def test_paper_example_different_from_avg_durations(self):
-        """Test that compute_priors differs from naive average duration approach.
+        """Test that compute_temporal_proportions differs from naive average duration approach.
 
         This is the key test showing the difference between:
         - Paper formula: average of (L_i,k / T_i)
         - Naive approach: mean(L_i,k) / sum(mean(L_i,j))
         """
-        # Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
-        # Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
-        subtask_durations = {
-            'subtask1': [80, 40],
-            'subtask2': [20, 160],
+        # Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
+        # Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
+        annotations = {
+            0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
+            1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
         }
-        trajectory_lengths = {
-            'subtask1': [100, 200],
-            'subtask2': [100, 200],
-        }
-        subtask_names = ['subtask1', 'subtask2']
 
-        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
+        result = compute_temporal_proportions(annotations)
 
         # Paper formula:
         # ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
         # ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
         assert abs(result['subtask1'] - 0.5) < 1e-6
         assert abs(result['subtask2'] - 0.5) < 1e-6
-
     def test_single_trajectory(self):
         """Test with a single trajectory."""
-        subtask_durations = {
-            'reach': [30],
-            'grasp': [20],
-            'lift': [50],
+        # T=100s, reach=30s, grasp=20s, lift=50s
+        annotations = {
+            0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
         }
-        trajectory_lengths = {
-            'reach': [100],
-            'grasp': [100],
-            'lift': [100],
-        }
-        subtask_names = ['grasp', 'lift', 'reach']  # sorted order
 
-        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
+        result = compute_temporal_proportions(annotations)
 
         assert abs(result['reach'] - 0.3) < 1e-6
         assert abs(result['grasp'] - 0.2) < 1e-6
@@ -111,49 +111,35 @@
     def test_sum_to_one(self):
         """Test that proportions always sum to 1."""
-        subtask_durations = {
-            'a': [10, 20, 30],
-            'b': [40, 50, 60],
-            'c': [50, 30, 10],
+        # Three episodes with varying proportions
+        annotations = {
+            0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]),  # 0.1, 0.4, 0.5
+            1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]),  # 0.2, 0.5, 0.3
+            2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]),  # 0.3, 0.6, 0.1
         }
-        trajectory_lengths = {
-            'a': [100, 100, 100],
-            'b': [100, 100, 100],
-            'c': [100, 100, 100],
-        }
-        subtask_names = ['a', 'b', 'c']
 
-        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
+        result = compute_temporal_proportions(annotations)
 
         total = sum(result.values())
         assert abs(total - 1.0) < 1e-6
 
-    def test_empty_subtask_names_raises(self):
-        """Test that empty subtask_names raises an error."""
-        with pytest.raises(ValueError, match="subtask_names cannot be empty"):
-            compute_priors({}, {}, [])
+    def test_empty_annotations_returns_empty(self):
+        """Test that empty annotations return an empty dict."""
+        result = compute_temporal_proportions({})
+        assert result == {}
 
-    def test_missing_subtask_gets_zero_before_normalization(self):
-        """Test handling of subtasks that appear in some but not all trajectories."""
-        # subtask1 appears in both, subtask2 only in first
-        subtask_durations = {
-            'subtask1': [50, 100],
-            'subtask2': [50],  # only in first trajectory
+    def test_uniform_proportions(self):
+        """Test with uniform proportions across subtasks."""
+        # Each subtask takes 25% of each episode
+        annotations = {
+            0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
+            1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
         }
-        trajectory_lengths = {
-            'subtask1': [100, 200],
-            'subtask2': [100],
-        }
-        subtask_names = ['subtask1', 'subtask2']
 
-        result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
+        result = compute_temporal_proportions(annotations)
 
-        # subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
-        # subtask2: 50/100 = 0.5 (only one occurrence)
-        # After normalization: both should be 0.5
-        assert result['subtask1'] > 0
-        assert result['subtask2'] > 0
-        assert abs(sum(result.values()) - 1.0) < 1e-6
+        for name in ['a', 'b', 'c', 'd']:
+            assert abs(result[name] - 0.25) < 1e-6
 
 
 class TestComputeTau:
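
With the Pydantic models and Formula (1) consolidated in sarm_utils, the annotation script, the processor, and the tests above all exercise the same code path. A minimal end-to-end sketch of the relocated API, using made-up timestamps (assumes only the imports shown):

from lerobot.policies.sarm.sarm_utils import (
    Subtask,
    SubtaskAnnotation,
    Timestamp,
    compute_temporal_proportions,
)

# One 100-second episode: reach for 30 s, then grasp for 70 s.
annotations = {
    0: SubtaskAnnotation(subtasks=[
        Subtask(name="reach", timestamps=Timestamp(start="00:00", end="00:30")),
        Subtask(name="grasp", timestamps=Timestamp(start="00:30", end="01:40")),
    ]),
}

priors = compute_temporal_proportions(annotations)
print(priors)  # {'reach': 0.3, 'grasp': 0.7}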
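
Those priors feed Formula (2), y_t = P_{k-1} + ᾱ_k × τ_t, implemented by compute_tau and compute_cumulative_progress_batch. A hand-computed check in plain Python, assuming τ_t is the within-stage progress as in the formula (no library calls):

# Priors from the sketch above: reach 0.3, grasp 0.7.
priors = {"reach": 0.3, "grasp": 0.7}

# Halfway through "grasp" (tau = 0.5): the completed prefix P_{k-1} is
# reach's 0.3, and the current stage contributes 0.7 * 0.5 = 0.35.
tau = 0.5
y_t = priors["reach"] + priors["grasp"] * tau
assert abs(y_t - 0.65) < 1e-9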