mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
simplify and cleanup code and move compute_temporal_proportions to utils
This commit is contained in:
@@ -72,13 +72,11 @@ import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
import multiprocessing as mp
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -86,24 +84,8 @@ from rich.tree import Tree
|
||||
from transformers import Qwen3VLMoeForConditionalGeneration, AutoProcessor
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Pydantic Models for SARM-style Annotation
|
||||
class Timestamp(BaseModel):
|
||||
"""Timestamp in MM:SS or SS format"""
|
||||
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
|
||||
end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
||||
|
||||
|
||||
class Subtask(BaseModel):
|
||||
"""Individual subtask/stage - must use EXACT names from provided list"""
|
||||
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
|
||||
timestamps: Timestamp
|
||||
|
||||
|
||||
class SubtaskAnnotation(BaseModel):
|
||||
"""Complete annotation for a robot manipulation episode"""
|
||||
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import compute_temporal_proportions
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
|
||||
def create_sarm_prompt(subtask_list: list[str]) -> str:
|
||||
"""
|
||||
@@ -769,59 +751,6 @@ def worker_process_episodes(
|
||||
return annotations
|
||||
|
||||
|
||||
def compute_temporal_proportions(annotations: dict[int, SubtaskAnnotation], fps: int = 30) -> dict[str, float]:
|
||||
"""
|
||||
Compute average temporal proportion for each subtask across all episodes.
|
||||
This is the key insight from SARM - use semantic subtasks instead of frame indices.
|
||||
"""
|
||||
# Collect all proportions per subtask
|
||||
subtask_proportions = {}
|
||||
|
||||
for annotation in annotations.values():
|
||||
# Calculate total episode duration
|
||||
total_duration = 0
|
||||
durations = {}
|
||||
|
||||
for subtask in annotation.subtasks:
|
||||
# Parse timestamps
|
||||
start_parts = subtask.timestamps.start.split(":")
|
||||
end_parts = subtask.timestamps.end.split(":")
|
||||
|
||||
if len(start_parts) == 2:
|
||||
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1])
|
||||
else:
|
||||
start_seconds = int(start_parts[0])
|
||||
|
||||
if len(end_parts) == 2:
|
||||
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1])
|
||||
else:
|
||||
end_seconds = int(end_parts[0])
|
||||
|
||||
duration = end_seconds - start_seconds
|
||||
durations[subtask.name] = duration
|
||||
total_duration += duration
|
||||
|
||||
# Calculate proportions for this episode
|
||||
if total_duration > 0:
|
||||
for name, duration in durations.items():
|
||||
if name not in subtask_proportions:
|
||||
subtask_proportions[name] = []
|
||||
subtask_proportions[name].append(duration / total_duration)
|
||||
|
||||
# Average across episodes
|
||||
avg_proportions = {
|
||||
name: sum(props) / len(props)
|
||||
for name, props in subtask_proportions.items()
|
||||
}
|
||||
|
||||
# Normalize to sum to 1.0
|
||||
total = sum(avg_proportions.values())
|
||||
if total > 0:
|
||||
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
|
||||
|
||||
return avg_proportions
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SARM-style subtask annotation using local GPU (Qwen3-VL)",
|
||||
@@ -1185,4 +1114,3 @@ Performance Tips:
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from transformers import CLIPModel, CLIPProcessor
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_priors, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
class SARMTransformer(nn.Module):
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -23,6 +21,7 @@ from PIL import Image
|
||||
import pandas as pd
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.processor.core import TransitionKey
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.processor import (
|
||||
@@ -119,7 +118,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
frame_indices: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
num_frames: int,
|
||||
is_batch: bool,
|
||||
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute episode metadata for all samples.
|
||||
|
||||
@@ -140,10 +138,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
absolute_indices_list.append(abs_indices)
|
||||
remaining_lengths.append(ep_end - abs_indices[0].item())
|
||||
|
||||
if is_batch:
|
||||
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
|
||||
else:
|
||||
return absolute_indices_list[0], remaining_lengths[0], episode_lengths[0]
|
||||
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
|
||||
|
||||
def _compute_stage_and_progress_for_frame(
|
||||
self,
|
||||
@@ -189,7 +184,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
cumulative_progress = compute_cumulative_progress_batch(
|
||||
tau, stage_idx, temporal_proportions_list
|
||||
)
|
||||
|
||||
return stage_idx, cumulative_progress
|
||||
|
||||
# No matching subtask found
|
||||
@@ -288,15 +282,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
if self.temporal_proportions is None or episode_index is None:
|
||||
return None, None
|
||||
|
||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
||||
|
||||
# Normalize inputs to numpy arrays
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
# Determine sequence length
|
||||
if video_features is not None and video_features.dim() >= 2:
|
||||
seq_len = video_features.shape[1] if is_batch else video_features.shape[0]
|
||||
seq_len = video_features.shape[1]
|
||||
else:
|
||||
seq_len = 1
|
||||
|
||||
@@ -315,24 +307,15 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
all_stage_labels.append(result[0])
|
||||
all_progress_targets.append(result[1])
|
||||
|
||||
if is_batch:
|
||||
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
|
||||
return all_stage_labels[0], all_progress_targets[0]
|
||||
return torch.stack(all_stage_labels, dim=0), torch.stack(all_progress_targets, dim=0)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Encode images, text, and normalize states in the transition."""
|
||||
from lerobot.processor.core import TransitionKey
|
||||
|
||||
new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
|
||||
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("Observation must be a dictionary")
|
||||
|
||||
# Encode images with CLIP
|
||||
image = observation.get(self.image_key)
|
||||
if image is None:
|
||||
raise ValueError(f"Image not found in observation for key: {self.image_key}")
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
@@ -342,10 +325,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
|
||||
state_key = self.config.state_key
|
||||
state_data = observation.get(state_key)
|
||||
if state_data is None:
|
||||
state_data = observation.get("state") or observation.get("observation.state")
|
||||
if state_data is None:
|
||||
raise ValueError(f"State data not found in observation (expected '{state_key}', 'state', or 'observation.state')")
|
||||
|
||||
if isinstance(state_data, torch.Tensor):
|
||||
state_tensor = state_data.float()
|
||||
@@ -355,8 +334,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)
|
||||
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if not isinstance(comp_data, dict):
|
||||
raise ValueError("COMPLEMENTARY_DATA must be a dictionary")
|
||||
|
||||
# Get task description from dataset (complementary_data["task"])
|
||||
task = comp_data.get('task')
|
||||
@@ -378,18 +355,17 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# Compute episode metadata if dataset_meta is available
|
||||
if self.dataset_meta is not None:
|
||||
is_batch = isinstance(frame_index, torch.Tensor) and frame_index.numel() > 1
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
# Determine number of frames from video features
|
||||
if video_features.dim() >= 2:
|
||||
num_frames = video_features.shape[1] if is_batch else video_features.shape[0]
|
||||
num_frames = video_features.shape[1]
|
||||
else:
|
||||
num_frames = 1
|
||||
|
||||
abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
|
||||
frame_indices, episode_indices, num_frames, is_batch
|
||||
frame_indices, episode_indices, num_frames
|
||||
)
|
||||
observation['absolute_frame_indices'] = abs_indices
|
||||
observation['remaining_length'] = remaining
|
||||
@@ -412,24 +388,14 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
"""Encode a batch of images using CLIP.
|
||||
|
||||
Args:
|
||||
images: Batched images with shape:
|
||||
- (B, C, H, W) for single frames, or
|
||||
- (B, T, C, H, W) for temporal sequences
|
||||
images: Batched images with shape: (B, T, C, H, W)
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, 512) or (B, T, 512)
|
||||
Encoded feature vectors with shape (B, T, 512)
|
||||
"""
|
||||
# Check if we have temporal dimension
|
||||
has_temporal = len(images.shape) == 5
|
||||
|
||||
if has_temporal: # Shape: (B, T, C, H, W)
|
||||
batch_size, seq_length = images.shape[0], images.shape[1]
|
||||
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
||||
elif len(images.shape) == 4: # Shape: (B, C, H, W)
|
||||
batch_size = images.shape[0]
|
||||
seq_length = 1
|
||||
else:
|
||||
raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}")
|
||||
batch_size, seq_length = images.shape[0], images.shape[1]
|
||||
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
||||
|
||||
# Convert to list of PIL images
|
||||
num_frames = images.shape[0]
|
||||
@@ -470,9 +436,8 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Concatenate all embeddings
|
||||
all_embeddings = torch.cat(all_embeddings) # (B*T, 512)
|
||||
|
||||
# Reshape back if temporal
|
||||
if has_temporal:
|
||||
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
|
||||
# Reshape back
|
||||
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 512)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
@@ -495,10 +460,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
# Get text features from CLIP
|
||||
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
||||
|
||||
# Handle single text case
|
||||
if text_embedding.dim() == 1:
|
||||
text_embedding = text_embedding.unsqueeze(0)
|
||||
|
||||
# Replicate for batch (B, 512)
|
||||
text_embedding = text_embedding.expand(batch_size, -1)
|
||||
|
||||
|
||||
@@ -14,25 +14,31 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utility functions for SARM progress label computation.
|
||||
|
||||
Implements formulas from the SARM paper:
|
||||
- Formula (1): Compute dataset-level temporal proportions (priors) ᾱ_k
|
||||
- Formula (2): Compute normalized progress targets y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Pydantic Models for SARM-style Annotation
|
||||
class Timestamp(BaseModel):
|
||||
"""Timestamp in MM:SS or SS format"""
|
||||
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
|
||||
end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
||||
|
||||
|
||||
def compute_priors(
|
||||
subtask_durations_per_trajectory: dict[str, list[float]],
|
||||
trajectory_lengths: dict[str, list[float]],
|
||||
subtask_names: list[str],
|
||||
) -> dict[str, float]:
|
||||
class Subtask(BaseModel):
|
||||
"""Individual subtask/stage - must use EXACT names from provided list"""
|
||||
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
|
||||
timestamps: Timestamp
|
||||
|
||||
|
||||
class SubtaskAnnotation(BaseModel):
|
||||
"""Complete annotation for a robot manipulation episode"""
|
||||
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
||||
|
||||
|
||||
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
|
||||
"""
|
||||
Compute dataset-level temporal proportions (priors) for each subtask.
|
||||
|
||||
@@ -40,61 +46,64 @@ def compute_priors(
|
||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
where:
|
||||
- M is the number of trajectories
|
||||
- L_{i,k} is the length of subtask k in trajectory i
|
||||
- T_i is the total length of trajectory i
|
||||
- M is the number of trajectories (episodes)
|
||||
- L_{i,k} is the duration of subtask k in trajectory i
|
||||
- T_i is the total duration of trajectory i
|
||||
|
||||
This averages the PROPORTION of each subtask within each trajectory,
|
||||
giving equal weight to all trajectories regardless of their absolute length.
|
||||
|
||||
Args:
|
||||
subtask_durations_per_trajectory: Dict mapping subtask name to list of
|
||||
(duration, trajectory_length) tuples for each occurrence
|
||||
trajectory_lengths: Dict mapping subtask name to list of trajectory lengths
|
||||
for each occurrence of that subtask
|
||||
subtask_names: Ordered list of subtask names
|
||||
annotations: Dict mapping episode index to SubtaskAnnotation object.
|
||||
Each annotation has a .subtasks list where each subtask has:
|
||||
- .name: subtask name
|
||||
- .timestamps.start: start time as "MM:SS" string
|
||||
- .timestamps.end: end time as "MM:SS" string
|
||||
fps: Frames per second (unused, kept for API compatibility)
|
||||
|
||||
Returns:
|
||||
Dict mapping subtask name to its temporal proportion (ᾱ_k)
|
||||
Dict mapping subtask name to its temporal proportion (ᾱ_k).
|
||||
Proportions are normalized to sum to 1.0.
|
||||
"""
|
||||
if not subtask_names:
|
||||
raise ValueError("subtask_names cannot be empty")
|
||||
subtask_proportions: dict[str, list[float]] = {}
|
||||
|
||||
# Compute proportion per occurrence: L_{i,k} / T_i
|
||||
subtask_proportions = {}
|
||||
for name in subtask_names:
|
||||
if name in subtask_durations_per_trajectory and name in trajectory_lengths:
|
||||
durations = subtask_durations_per_trajectory[name]
|
||||
traj_lengths = trajectory_lengths[name]
|
||||
for annotation in annotations.values():
|
||||
total_duration = 0
|
||||
durations: dict[str, int] = {}
|
||||
|
||||
if len(durations) != len(traj_lengths):
|
||||
raise ValueError(
|
||||
f"Mismatch in lengths for subtask '{name}': "
|
||||
f"{len(durations)} durations vs {len(traj_lengths)} trajectory lengths"
|
||||
)
|
||||
for subtask in annotation.subtasks:
|
||||
start_parts = subtask.timestamps.start.split(":")
|
||||
end_parts = subtask.timestamps.end.split(":")
|
||||
|
||||
# Compute L_{i,k} / T_i for each occurrence
|
||||
proportions = []
|
||||
for duration, traj_len in zip(durations, traj_lengths):
|
||||
if traj_len > 0:
|
||||
proportions.append(duration / traj_len)
|
||||
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
|
||||
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
|
||||
|
||||
# Average across all occurrences: (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
subtask_proportions[name] = np.mean(proportions) if proportions else 0.0
|
||||
else:
|
||||
subtask_proportions[name] = 0.0
|
||||
duration = end_seconds - start_seconds
|
||||
durations[subtask.name] = duration
|
||||
total_duration += duration
|
||||
|
||||
# Normalize to ensure sum = 1 (handles floating point errors and missing subtasks)
|
||||
total = sum(subtask_proportions.values())
|
||||
# Calculate L_{i,k} / T_i for each subtask in this trajectory
|
||||
if total_duration > 0:
|
||||
for name, duration in durations.items():
|
||||
if name not in subtask_proportions:
|
||||
subtask_proportions[name] = []
|
||||
subtask_proportions[name].append(duration / total_duration)
|
||||
|
||||
if not subtask_proportions:
|
||||
return {}
|
||||
|
||||
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
avg_proportions = {
|
||||
name: sum(props) / len(props)
|
||||
for name, props in subtask_proportions.items()
|
||||
}
|
||||
|
||||
# Normalize to ensure sum = 1
|
||||
total = sum(avg_proportions.values())
|
||||
if total > 0:
|
||||
subtask_proportions = {
|
||||
name: prop / total for name, prop in subtask_proportions.items()
|
||||
}
|
||||
else:
|
||||
raise ValueError("Cannot compute temporal proportions: all proportions are zero. "
|
||||
"Check that your dataset has valid subtask annotations with start/end times.")
|
||||
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
|
||||
|
||||
return subtask_proportions
|
||||
return avg_proportions
|
||||
|
||||
|
||||
def compute_tau(
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
Tests for SARM utility functions.
|
||||
|
||||
Tests the implementation of SARM paper formulas:
|
||||
- Formula (1): compute_priors - dataset-level temporal proportions
|
||||
- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
|
||||
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
|
||||
"""
|
||||
|
||||
@@ -26,15 +26,31 @@ import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
compute_priors,
|
||||
compute_temporal_proportions,
|
||||
compute_tau,
|
||||
compute_cumulative_progress_batch,
|
||||
)
|
||||
|
||||
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
|
||||
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
|
||||
return SubtaskAnnotation(
|
||||
subtasks=[
|
||||
Subtask(
|
||||
name=name,
|
||||
timestamps=Timestamp(
|
||||
start=f"{start // 60:02d}:{start % 60:02d}",
|
||||
end=f"{end // 60:02d}:{end % 60:02d}"
|
||||
)
|
||||
)
|
||||
for name, start, end in subtasks
|
||||
]
|
||||
)
|
||||
|
||||
class TestComputePriors:
|
||||
"""Tests for compute_priors (SARM Paper Formula 1).
|
||||
|
||||
class TestComputeTemporalProportions:
|
||||
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
|
||||
|
||||
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
@@ -45,42 +61,34 @@ class TestComputePriors:
|
||||
def test_basic_two_trajectories_equal_proportions(self):
|
||||
"""Test with two trajectories that have equal proportions."""
|
||||
# Both trajectories: subtask1 = 50%, subtask2 = 50%
|
||||
subtask_durations = {
|
||||
'subtask1': [50, 100], # durations
|
||||
'subtask2': [50, 100],
|
||||
# Traj 1: T=100s, subtask1=50s, subtask2=50s
|
||||
# Traj 2: T=200s, subtask1=100s, subtask2=100s
|
||||
annotations = {
|
||||
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
|
||||
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'subtask1': [100, 200],
|
||||
'subtask2': [100, 200],
|
||||
}
|
||||
subtask_names = ['subtask1', 'subtask2']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Both should be 0.5
|
||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||
|
||||
def test_paper_example_different_from_avg_durations(self):
|
||||
"""Test that compute_priors differs from naive average duration approach.
|
||||
"""Test that compute_temporal_proportions differs from naive average duration approach.
|
||||
|
||||
This is the key test showing the difference between:
|
||||
- Paper formula: average of (L_i,k / T_i)
|
||||
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
|
||||
"""
|
||||
# Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
|
||||
# Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
|
||||
subtask_durations = {
|
||||
'subtask1': [80, 40],
|
||||
'subtask2': [20, 160],
|
||||
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
|
||||
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
|
||||
annotations = {
|
||||
0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
|
||||
1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'subtask1': [100, 200],
|
||||
'subtask2': [100, 200],
|
||||
}
|
||||
subtask_names = ['subtask1', 'subtask2']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Paper formula:
|
||||
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
|
||||
@@ -88,22 +96,14 @@ class TestComputePriors:
|
||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||
|
||||
|
||||
def test_single_trajectory(self):
|
||||
"""Test with a single trajectory."""
|
||||
subtask_durations = {
|
||||
'reach': [30],
|
||||
'grasp': [20],
|
||||
'lift': [50],
|
||||
# T=100s, reach=30s, grasp=20s, lift=50s
|
||||
annotations = {
|
||||
0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'reach': [100],
|
||||
'grasp': [100],
|
||||
'lift': [100],
|
||||
}
|
||||
subtask_names = ['grasp', 'lift', 'reach'] # sorted order
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
assert abs(result['reach'] - 0.3) < 1e-6
|
||||
assert abs(result['grasp'] - 0.2) < 1e-6
|
||||
@@ -111,49 +111,35 @@ class TestComputePriors:
|
||||
|
||||
def test_sum_to_one(self):
|
||||
"""Test that proportions always sum to 1."""
|
||||
subtask_durations = {
|
||||
'a': [10, 20, 30],
|
||||
'b': [40, 50, 60],
|
||||
'c': [50, 30, 10],
|
||||
# Three episodes with varying proportions
|
||||
annotations = {
|
||||
0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
|
||||
1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
|
||||
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'a': [100, 100, 100],
|
||||
'b': [100, 100, 100],
|
||||
'c': [100, 100, 100],
|
||||
}
|
||||
subtask_names = ['a', 'b', 'c']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
total = sum(result.values())
|
||||
assert abs(total - 1.0) < 1e-6
|
||||
|
||||
def test_empty_subtask_names_raises(self):
|
||||
"""Test that empty subtask_names raises an error."""
|
||||
with pytest.raises(ValueError, match="subtask_names cannot be empty"):
|
||||
compute_priors({}, {}, [])
|
||||
def test_empty_annotations_returns_empty(self):
|
||||
"""Test that empty annotations returns empty dict."""
|
||||
result = compute_temporal_proportions({})
|
||||
assert result == {}
|
||||
|
||||
def test_missing_subtask_gets_zero_before_normalization(self):
|
||||
"""Test handling of subtasks that appear in some but not all trajectories."""
|
||||
# subtask1 appears in both, subtask2 only in first
|
||||
subtask_durations = {
|
||||
'subtask1': [50, 100],
|
||||
'subtask2': [50], # only in first trajectory
|
||||
def test_uniform_proportions(self):
|
||||
"""Test with uniform proportions across subtasks."""
|
||||
# Each subtask takes 25% of each episode
|
||||
annotations = {
|
||||
0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
|
||||
1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'subtask1': [100, 200],
|
||||
'subtask2': [100],
|
||||
}
|
||||
subtask_names = ['subtask1', 'subtask2']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
|
||||
# subtask2: 50/100 = 0.5 (only one occurrence)
|
||||
# After normalization: both should be 0.5
|
||||
assert result['subtask1'] > 0
|
||||
assert result['subtask2'] > 0
|
||||
assert abs(sum(result.values()) - 1.0) < 1e-6
|
||||
for name in ['a', 'b', 'c', 'd']:
|
||||
assert abs(result[name] - 0.25) < 1e-6
|
||||
|
||||
|
||||
class TestComputeTau:
|
||||
|
||||
Reference in New Issue
Block a user