add tests, implement formula 1,2 correctly and cleanup

This commit is contained in:
Pepijn
2025-11-27 14:04:01 +01:00
parent 3ed0425d2c
commit f2ad86831d
7 changed files with 861 additions and 274 deletions
+392
View File
@@ -0,0 +1,392 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for SARM utility functions.
Tests the implementation of SARM paper formulas:
- Formula (1): compute_priors - dataset-level temporal proportions
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
"""
import pytest
import numpy as np
import torch
from lerobot.policies.sarm.sarm_utils import (
compute_priors,
compute_tau,
compute_cumulative_progress_batch,
)
class TestComputePriors:
"""Tests for compute_priors (SARM Paper Formula 1).
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
Key insight: This averages the PROPORTION of each subtask within each trajectory,
giving equal weight to all trajectories regardless of absolute length.
"""
def test_basic_two_trajectories_equal_proportions(self):
"""Test with two trajectories that have equal proportions."""
# Both trajectories: subtask1 = 50%, subtask2 = 50%
subtask_durations = {
'subtask1': [50, 100], # durations
'subtask2': [50, 100],
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
# Both should be 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_paper_example_different_from_avg_durations(self):
"""Test that compute_priors differs from naive average duration approach.
This is the key test showing the difference between:
- Paper formula: average of (L_i,k / T_i)
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
"""
# Episode 1: T=100, subtask1=80, subtask2=20 (proportions: 0.8, 0.2)
# Episode 2: T=200, subtask1=40, subtask2=160 (proportions: 0.2, 0.8)
subtask_durations = {
'subtask1': [80, 40],
'subtask2': [20, 160],
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100, 200],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
# Paper formula:
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
assert abs(result['subtask1'] - 0.5) < 1e-6
assert abs(result['subtask2'] - 0.5) < 1e-6
def test_single_trajectory(self):
"""Test with a single trajectory."""
subtask_durations = {
'reach': [30],
'grasp': [20],
'lift': [50],
}
trajectory_lengths = {
'reach': [100],
'grasp': [100],
'lift': [100],
}
subtask_names = ['grasp', 'lift', 'reach'] # sorted order
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
assert abs(result['reach'] - 0.3) < 1e-6
assert abs(result['grasp'] - 0.2) < 1e-6
assert abs(result['lift'] - 0.5) < 1e-6
def test_sum_to_one(self):
"""Test that proportions always sum to 1."""
subtask_durations = {
'a': [10, 20, 30],
'b': [40, 50, 60],
'c': [50, 30, 10],
}
trajectory_lengths = {
'a': [100, 100, 100],
'b': [100, 100, 100],
'c': [100, 100, 100],
}
subtask_names = ['a', 'b', 'c']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
total = sum(result.values())
assert abs(total - 1.0) < 1e-6
def test_empty_subtask_names_raises(self):
"""Test that empty subtask_names raises an error."""
with pytest.raises(ValueError, match="subtask_names cannot be empty"):
compute_priors({}, {}, [])
def test_missing_subtask_gets_zero_before_normalization(self):
"""Test handling of subtasks that appear in some but not all trajectories."""
# subtask1 appears in both, subtask2 only in first
subtask_durations = {
'subtask1': [50, 100],
'subtask2': [50], # only in first trajectory
}
trajectory_lengths = {
'subtask1': [100, 200],
'subtask2': [100],
}
subtask_names = ['subtask1', 'subtask2']
result = compute_priors(subtask_durations, trajectory_lengths, subtask_names)
# subtask1: (50/100 + 100/200) / 2 = (0.5 + 0.5) / 2 = 0.5
# subtask2: 50/100 = 0.5 (only one occurrence)
# After normalization: both should be 0.5
assert result['subtask1'] > 0
assert result['subtask2'] > 0
assert abs(sum(result.values()) - 1.0) < 1e-6
class TestComputeTau:
"""Tests for compute_tau (within-subtask progress).
Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
"""
def test_at_start(self):
"""τ should be 0 at subtask start."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_at_end(self):
"""τ should be 1 at subtask end."""
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_at_middle(self):
"""τ should be 0.5 at subtask midpoint."""
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
assert abs(tau - 0.5) < 1e-6
def test_quarter_progress(self):
"""Test τ at 25% through subtask."""
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
assert abs(tau - 0.25) < 1e-6
def test_zero_duration_subtask(self):
"""τ should be 1.0 for zero-duration subtask."""
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
assert tau == 1.0
def test_clamps_below_zero(self):
"""τ should be clamped to 0 if frame is before subtask."""
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
assert tau == 0.0
def test_clamps_above_one(self):
"""τ should be clamped to 1 if frame is after subtask."""
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
assert tau == 1.0
def test_float_inputs(self):
"""Test with float frame indices (from interpolation)."""
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
expected = (25.5 - 10.0) / (50.0 - 10.0)
assert abs(tau - expected) < 1e-6
class TestComputeCumulativeProgressBatchScalar:
"""Tests for compute_cumulative_progress_batch with scalar inputs (normalized progress y_t).
Formula: y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
"""
def test_first_subtask_start(self):
"""y should be 0 at start of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=proportions)
assert y == 0.0
def test_first_subtask_end(self):
"""y should equal ᾱ_1 at end of first subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_start(self):
"""y should equal P_1 at start of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.3) < 1e-6
def test_second_subtask_end(self):
"""y should equal P_2 at end of second subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=proportions)
assert abs(y - 0.8) < 1e-6 # 0.3 + 0.5
def test_third_subtask_end(self):
"""y should be 1.0 at end of last subtask."""
proportions = [0.3, 0.5, 0.2]
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=proportions)
assert abs(y - 1.0) < 1e-6
def test_midpoint_of_subtask(self):
"""Test progress at midpoint of a subtask."""
proportions = [0.4, 0.6]
# At τ=0.5 in subtask 1: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=proportions)
assert abs(y - 0.2) < 1e-6
# At τ=0.5 in subtask 2: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=proportions)
assert abs(y - 0.7) < 1e-6
def test_uniform_proportions(self):
"""Test with uniform proportions."""
proportions = [0.25, 0.25, 0.25, 0.25]
# At end of each subtask, progress should be 0.25, 0.5, 0.75, 1.0
for i in range(4):
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=i, alpha=proportions)
expected = (i + 1) * 0.25
assert abs(y - expected) < 1e-6
class TestComputeCumulativeProgressBatchTensor:
"""Tests for compute_cumulative_progress_batch with tensor inputs (GPU batch version)."""
def test_tensor_matches_scalar_version(self):
"""Test that tensor version matches scalar version."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(len(proportions) + 1, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
test_cases = [
(0.0, 0), # start of subtask 0
(1.0, 0), # end of subtask 0
(0.0, 1), # start of subtask 1
(0.5, 1), # middle of subtask 1
(1.0, 2), # end of subtask 2
]
for tau_val, stage_idx in test_cases:
# Scalar version
expected = compute_cumulative_progress_batch(tau_val, stage_idx, proportions)
# Tensor version (single element)
tau = torch.tensor([[[tau_val]]]) # (1, 1, 1)
stages = torch.tensor([[stage_idx]]) # (1, 1)
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
assert abs(result[0, 0, 0].item() - expected) < 1e-6
def test_batch_processing(self):
"""Test batch processing with multiple samples."""
proportions = [0.4, 0.6]
alpha = torch.tensor(proportions, dtype=torch.float32)
cumulative = torch.zeros(3, dtype=torch.float32)
cumulative[1:] = torch.cumsum(alpha, dim=0)
# Batch of 2 samples, sequence length 3
tau = torch.tensor([
[[0.0], [0.5], [1.0]], # sample 1
[[0.0], [0.5], [1.0]], # sample 2
])
stages = torch.tensor([
[0, 0, 0], # sample 1: all in subtask 0
[1, 1, 1], # sample 2: all in subtask 1
])
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
# Sample 1: subtask 0 with tau 0, 0.5, 1.0 -> y = 0, 0.2, 0.4
assert abs(result[0, 0, 0].item() - 0.0) < 1e-6
assert abs(result[0, 1, 0].item() - 0.2) < 1e-6
assert abs(result[0, 2, 0].item() - 0.4) < 1e-6
# Sample 2: subtask 1 with tau 0, 0.5, 1.0 -> y = 0.4, 0.7, 1.0
assert abs(result[1, 0, 0].item() - 0.4) < 1e-6
assert abs(result[1, 1, 0].item() - 0.7) < 1e-6
assert abs(result[1, 2, 0].item() - 1.0) < 1e-6
def test_auto_compute_cumulative_prior(self):
"""Test that cumulative_prior is auto-computed when not provided."""
proportions = [0.3, 0.5, 0.2]
alpha = torch.tensor(proportions, dtype=torch.float32)
tau = torch.tensor([[[0.5]]])
stages = torch.tensor([[1]])
# Without cumulative_prior (should auto-compute)
result = compute_cumulative_progress_batch(tau, stages, alpha)
# Expected: P_0 + alpha_1 * 0.5 = 0.3 + 0.5 * 0.5 = 0.55
assert abs(result[0, 0, 0].item() - 0.55) < 1e-6
class TestEndToEndProgressLabeling:
"""End-to-end tests for progress label computation."""
def test_consistent_semantic_meaning(self):
"""Test that same subtask completion maps to same progress across trajectories.
This is the key semantic property: "end of subtask 1" should always
mean the same progress value regardless of trajectory speed.
"""
proportions = [0.3, 0.5, 0.2]
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
tau_fast = compute_tau(30, 0, 30) # = 1.0
y_fast = compute_cumulative_progress_batch(tau_fast, 0, proportions)
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
tau_slow = compute_tau(90, 0, 90) # = 1.0
y_slow = compute_cumulative_progress_batch(tau_slow, 0, proportions)
# Both should map to same progress (0.3 = end of subtask 1)
assert abs(y_fast - y_slow) < 1e-6
assert abs(y_fast - 0.3) < 1e-6
def test_monotonic_within_subtask(self):
"""Test that progress is monotonically increasing within a subtask."""
proportions = [0.4, 0.6]
prev_y = -1
for tau in np.linspace(0, 1, 11):
y = compute_cumulative_progress_batch(tau, 0, proportions)
assert y > prev_y or (tau == 0 and y == 0)
prev_y = y
def test_continuous_across_subtasks(self):
"""Test that progress is continuous at subtask boundaries."""
proportions = [0.3, 0.5, 0.2]
# End of subtask 0 (tau=1.0)
y_end_0 = compute_cumulative_progress_batch(1.0, 0, proportions)
# Start of subtask 1 (tau=0.0)
y_start_1 = compute_cumulative_progress_batch(0.0, 1, proportions)
# Should be equal (P_1 = 0.3)
assert abs(y_end_0 - y_start_1) < 1e-6
# End of subtask 1
y_end_1 = compute_cumulative_progress_batch(1.0, 1, proportions)
# Start of subtask 2
y_start_2 = compute_cumulative_progress_batch(0.0, 2, proportions)
# Should be equal (P_2 = 0.8)
assert abs(y_end_1 - y_start_2) < 1e-6