mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
simplify and cleanup code and move compute_temporal_proportions to utils
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
Tests for SARM utility functions.
|
||||
|
||||
Tests the implementation of SARM paper formulas:
|
||||
- Formula (1): compute_priors - dataset-level temporal proportions
|
||||
- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
|
||||
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
|
||||
"""
|
||||
|
||||
@@ -26,15 +26,31 @@ import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
compute_priors,
|
||||
compute_temporal_proportions,
|
||||
compute_tau,
|
||||
compute_cumulative_progress_batch,
|
||||
)
|
||||
|
||||
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
|
||||
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
|
||||
return SubtaskAnnotation(
|
||||
subtasks=[
|
||||
Subtask(
|
||||
name=name,
|
||||
timestamps=Timestamp(
|
||||
start=f"{start // 60:02d}:{start % 60:02d}",
|
||||
end=f"{end // 60:02d}:{end % 60:02d}"
|
||||
)
|
||||
)
|
||||
for name, start, end in subtasks
|
||||
]
|
||||
)
|
||||
|
||||
class TestComputePriors:
|
||||
"""Tests for compute_priors (SARM Paper Formula 1).
|
||||
|
||||
class TestComputeTemporalProportions:
|
||||
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
|
||||
|
||||
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
@@ -45,65 +61,49 @@ class TestComputePriors:
|
||||
def test_basic_two_trajectories_equal_proportions(self):
|
||||
"""Test with two trajectories that have equal proportions."""
|
||||
# Both trajectories: subtask1 = 50%, subtask2 = 50%
|
||||
subtask_durations = {
|
||||
'subtask1': [50, 100], # durations
|
||||
'subtask2': [50, 100],
|
||||
# Traj 1: T=100s, subtask1=50s, subtask2=50s
|
||||
# Traj 2: T=200s, subtask1=100s, subtask2=100s
|
||||
annotations = {
|
||||
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
|
||||
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'subtask1': [100, 200],
|
||||
'subtask2': [100, 200],
|
||||
}
|
||||
subtask_names = ['subtask1', 'subtask2']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Both should be 0.5
|
||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||
|
||||
def test_paper_example_different_from_avg_durations(self):
|
||||
"""Test that compute_priors differs from naive average duration approach.
|
||||
"""Test that compute_temporal_proportions differs from naive average duration approach.
|
||||
|
||||
This is the key test showing the difference between:
|
||||
- Paper formula: average of (L_i,k / T_i)
|
||||
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
|
||||
"""
|
||||
# Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
|
||||
# Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
|
||||
subtask_durations = {
|
||||
'subtask1': [80, 40],
|
||||
'subtask2': [20, 160],
|
||||
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
|
||||
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
|
||||
annotations = {
|
||||
0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
|
||||
1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'subtask1': [100, 200],
|
||||
'subtask2': [100, 200],
|
||||
}
|
||||
subtask_names = ['subtask1', 'subtask2']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Paper formula:
|
||||
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
|
||||
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
|
||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||
|
||||
|
||||
def test_single_trajectory(self):
|
||||
"""Test with a single trajectory."""
|
||||
subtask_durations = {
|
||||
'reach': [30],
|
||||
'grasp': [20],
|
||||
'lift': [50],
|
||||
# T=100s, reach=30s, grasp=20s, lift=50s
|
||||
annotations = {
|
||||
0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'reach': [100],
|
||||
'grasp': [100],
|
||||
'lift': [100],
|
||||
}
|
||||
subtask_names = ['grasp', 'lift', 'reach'] # sorted order
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
assert abs(result['reach'] - 0.3) < 1e-6
|
||||
assert abs(result['grasp'] - 0.2) < 1e-6
|
||||
@@ -111,49 +111,35 @@ class TestComputePriors:
|
||||
|
||||
def test_sum_to_one(self):
|
||||
"""Test that proportions always sum to 1."""
|
||||
subtask_durations = {
|
||||
'a': [10, 20, 30],
|
||||
'b': [40, 50, 60],
|
||||
'c': [50, 30, 10],
|
||||
# Three episodes with varying proportions
|
||||
annotations = {
|
||||
0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
|
||||
1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
|
||||
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'a': [100, 100, 100],
|
||||
'b': [100, 100, 100],
|
||||
'c': [100, 100, 100],
|
||||
}
|
||||
subtask_names = ['a', 'b', 'c']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
total = sum(result.values())
|
||||
assert abs(total - 1.0) < 1e-6
|
||||
|
||||
def test_empty_subtask_names_raises(self):
|
||||
"""Test that empty subtask_names raises an error."""
|
||||
with pytest.raises(ValueError, match="subtask_names cannot be empty"):
|
||||
compute_priors({}, {}, [])
|
||||
def test_empty_annotations_returns_empty(self):
|
||||
"""Test that empty annotations returns empty dict."""
|
||||
result = compute_temporal_proportions({})
|
||||
assert result == {}
|
||||
|
||||
def test_missing_subtask_gets_zero_before_normalization(self):
|
||||
"""Test handling of subtasks that appear in some but not all trajectories."""
|
||||
# subtask1 appears in both, subtask2 only in first
|
||||
subtask_durations = {
|
||||
'subtask1': [50, 100],
|
||||
'subtask2': [50], # only in first trajectory
|
||||
def test_uniform_proportions(self):
|
||||
"""Test with uniform proportions across subtasks."""
|
||||
# Each subtask takes 25% of each episode
|
||||
annotations = {
|
||||
0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
|
||||
1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
|
||||
}
|
||||
trajectory_lengths = {
|
||||
'subtask1': [100, 200],
|
||||
'subtask2': [100],
|
||||
}
|
||||
subtask_names = ['subtask1', 'subtask2']
|
||||
|
||||
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
|
||||
# subtask2: 50/100 = 0.5 (only one occurrence)
|
||||
# After normalization: both should be 0.5
|
||||
assert result['subtask1'] > 0
|
||||
assert result['subtask2'] > 0
|
||||
assert abs(sum(result.values()) - 1.0) < 1e-6
|
||||
for name in ['a', 'b', 'c', 'd']:
|
||||
assert abs(result[name] - 0.25) < 1e-6
|
||||
|
||||
|
||||
class TestComputeTau:
|
||||
|
||||
Reference in New Issue
Block a user